├── .gitignore ├── LICENSE ├── README.md ├── assets └── logo-final.png ├── examples ├── pipelines │ ├── combined_filter.py │ ├── embed_datasets.py │ ├── embed_datasets.sh │ ├── filter_datasets.sh │ ├── score_complexity.py │ ├── score_complexity_dataset.py │ ├── score_complexity_dataset.sh │ ├── score_quality.py │ ├── score_quality_dataset.sh │ ├── score_vllm.py │ └── utils.py └── train │ ├── dpo.sh │ ├── sft.sh │ └── train_scorers.sh ├── requirements.txt ├── sample_little.py ├── setup.py └── src └── deita ├── __init__.py ├── alignment ├── __init__.py ├── constants.py ├── conversation.py ├── dpo_train.py ├── flash_attn │ ├── bloom_flash_attention.py │ └── triton_flash_attention.py ├── train.py └── train_scorers.py ├── data └── sample_ultrafeedback.py ├── ds_configs ├── deepspeed_config_zero2_no_offload.json ├── deepspped_llama_x.json └── stage3_no_offloading_accelerate.json ├── pipeline ├── __init__.py ├── base.py ├── embed_pipeline.py ├── filter_pipeline.py ├── score_pipeline.py └── utils.py └── selection ├── __init__.py ├── embedder ├── __init__.py ├── base.py ├── clm_embedder.py ├── conversation.py └── utils.py ├── filter ├── __init__.py ├── base.py ├── combined_filter.py └── utils.py └── scorer ├── __init__.py ├── base.py ├── llama_scorer.py └── mistral_scorer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *egg-info* 3 | *build* 4 | *dist* 5 | *egg-info 6 | .history 7 | data/ 8 | outputs/ 9 | logs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deita 2 | 3 |

4 | 5 |

6 | 7 | 8 |

9 | 🤗 HF Repo    10 | 📄 Paper    11 | 📚 6K Data    12 | 📚 10K Data 13 |

 14 | 15 | 16 | Welcome to the Deita (**D**ata-**E**fficient **I**nstruction **T**uning for **A**lignment) Project! 17 | 18 | We will continue to update, so please stay tuned! 19 | 20 | 21 | ## What is Deita? 22 | Deita is an open-source project designed to facilitate **Automatic Data Selection** for instruction tuning in Large Language Models (LLMs). 23 | 24 | It includes: 25 | - **Open-sourced Toolkits** for automatic data selection in instruction tuning 26 | - **Deita Datasets**: A series of extremely *lightweight*, high-quality alignment SFT data. We provide 6k-sized and 10k-sized datasets in the first release 27 | - **Deita Models**: A series of powerful models on par with SOTA chat LLMs, trained with an extremely efficient instruction tuning process. Deita models can be obtained by training with 10x less instruction tuning data compared with other SOTA LLMs 28 | 29 | ## News 30 | - :fire: [03/2024] Our datasets have been used by Hugging Face to create the [Zephyr Gemma Model](https://huggingface.co/collections/HuggingFaceH4/zephyr-7b-gemma-65e1fd82d26b426e3e63d956). 31 | - 📄 [01/2024] The Deita paper [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) has been accepted by ICLR 2024! 32 | - :fire: [01/2024] [Deita pipelines](#deita-pipelines) have been released! With one line of code and a few configuration options, a high-quality data subset for alignment can be selected. 33 | - 📚 [01/2024] Our scorer datasets [deita-complexity-scorer-data](https://huggingface.co/datasets/hkust-nlp/deita-complexity-scorer-data) and [deita-quality-scorer-data](https://huggingface.co/datasets/hkust-nlp/deita-quality-scorer-data) have been released. 34 | - :fire: [12/2023] We released the first collection of Deita resources [here](https://huggingface.co/collections/hkust-nlp/deita-6569c198c174808d94cf5bd4), which includes a series of extremely lightweight, effective SFT datasets, the data complexity/quality scorer models, as well as the resulting Deita chat models. 35 | 36 | ## Performance 37 | :bell: Still curious about how far a small amount of high-quality data can take LLMs? 38 | 39 | Deita may provide an answer for you: 40 | 41 | **🔦 Highlights** 42 | | Model | Align | Data Size | MT-Bench | AlpacaEval(%) | 43 | |------------------------------------------------|--------------|------------|----------|---------------| 44 | | Zephyr-7B-sft | SFT | 200K | 5.32 | 75.12 | 45 | | $\text{Zephyr-7B-}\beta$ | SFT + DPO | 200K SFT + 60K DPO | 7.34 | 90.60 | 46 | | OpenChat-3.5 | C-RLFT | >> 70K C-RLFT | 7.81 | 88.51 | 47 | | Starling-7B | C-RLFT + APA | >> 70K C-RLFT + 183K APA | 8.09 | 91.99 | 48 | | Tulu-2-13B | SFT | 326K | 6.70 | 78.90 | 49 | | Tulu-2-13B+DPO | SFT + DPO | 326K SFT + 60K DPO | 7.00 | 89.50 | 50 | | LLaMA2-13B-Chat | SFT + PPO | -- | 6.65 | 81.09 | 51 | | WizardLM-13B-v1.2 | SFT | >70K | 7.09 | 89.17 | 52 | | Vicuna-13B-v1.5 | SFT | >125K | 6.57 | 78.80 | 53 | | DEITA-7B-v1.0 (6K) | SFT | 6K | 7.22 | 80.78 | 54 | | DEITA-7B-v1.0-sft | SFT | 10K | 7.32 | 81.67 | 55 | | DEITA-7B-v1.0 | SFT + DPO | 6K SFT + 10K DPO | 7.55 | 90.06 | 56 | 57 | DEITA models are based on Mistral-7B-v0.1. :fire: 58 | 59 | Please refer to [this table](#chart\_with\_upwards\_trend-full-evaluations) for full evaluations, including the Open LLM Leaderboard, which also covers DEITA models with LLaMA base models and comparisons with other data selection approaches. 60 | 61 | 62 | 63 | ## :chart_with_upwards_trend: Full Evaluations 64 | 65 | 
66 | See full evaluations 67 | 68 | | Model | Align | Data Size | MT-Bench | AlpacaEval(%) | OpenLLM (Avg.) | 69 | |------------------------------------------------|-----------|------------|----------|---------------|----------------| 70 | | **Proprietary Models** | | | | | | 71 | | GPT-4-Turbo | ? | -- | 9.32 | 97.70 | -- | 72 | | GPT-4 | SFT + PPO | -- | 8.99 | 95.03 | -- | 73 | | Claude-2 | SFT + PPO | -- | 8.06 | 91.36 | -- | 74 | | GPT-3.5-turbo | SFT + PPO | -- | 7.94 | 89.37 | -- | 75 | | **Open-sourced Models based on LLaMA-1-13B** | | | | | | 76 | | LIMA | SFT | 1K SFT | 4.29 | 41.98 | 59.82 | 77 | | WizardLM-13B | SFT | 70K SFT | 6.35 | 75.31 | 58.96 | 78 | | Vicuna-13B-v1.3 | SFT | 125K SFT | 6.39 | 82.11 | 60.01 | 79 | | Random | SFT | 10K SFT | 6.03 | 71.52 | 60.14 | 80 | | DEITA-LLaMA1-13B-v1.0-sft | SFT | 10K SFT | 6.60 | 78.01 | 64.27 | 81 | | **Open-sourced Models based on LLaMA-2-13B** | | | | | | 82 | | Tulu-2-13B | SFT | 326K SFT | 6.70 | 78.90 | -- | 83 | | Tulu-2-13B+DPO | SFT + DPO | 326K SFT + 60K DPO | 7.00 | 89.50 | -- | 84 | | LLaMA2-13B-Chat | SFT + PPO | -- | 6.65 | 81.09 | -- | 85 | | WizardLM-13B-v1.2 | SFT | >70K SFT | 7.09 | 89.17 | -- | 86 | | Vicuna-13B-v1.5 | SFT | 125K SFT | 6.57 | 78.80 | 61.63 | 87 | | Random | SFT | 10K SFT | 5.78 | 65.19 | 61.32 | 88 | | DEITA-LLaMA2-13B-v1.0-sft | SFT | 10K SFT | 6.79 | 81.09 | 62.71 | 89 | | **Open-sourced Models based on Mistral-7B** | | | | | | 90 | | Mistral-7B-Instruct-v0.1 | -- | -- | 6.84 | 69.65 | 60.45 | 91 | | Zephyr-7B-sft | SFT | 200K SFT | 5.32 | 75.12 | 60.93 | 92 | | $\text{Zephyr-7B-}\beta$ | SFT + DPO | 200K SFT + 60K DPO | 7.34 | 90.60 | 66.36 | 93 | | OpenChat-3.5 | C-RLFT | >> 70K C-RLFT | 7.81 | 88.51 | -- | 94 | | Starling-7B | C-RLFT + APA | >>70K C-RLFT + 183K APA | 8.09 | 91.99 | -- | 95 | | Random | SFT | 10K SFT | 5.89 | 56.90 | 61.72 | 96 | | DEITA-7B-v1.0-sft (6K) | SFT | 6K SFT | 7.22 | 80.78 | 64.94 | 97 | | DEITA-7B-v1.0-sft (10K) | SFT | 10K SFT | 7.32 | 81.67 | 64.00 | 98 | | DEITA-7B-v1.0 | SFT + DPO | 6K SFT + 10K DPO | 7.55 | 90.06 | 69.86 | 99 | 100 | 101 |
102 | 103 | ## :rocket: Deita Resources 104 | 105 | | Resource | Link | License | 106 | |------------------------------------------------|-----------|------------| 107 | | **Deita Datasets** | | | 108 | | deita-6k-v0 | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-6k-v0) | [MIT License](https://opensource.org/license/mit/) | 109 | | deita-10k-v0 | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-10k-v0) | [MIT License](https://opensource.org/license/mit/) | 110 | | deita-complexity-scorer-data | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-complexity-scorer-data) | [MIT License](https://opensource.org/license/mit/) | 111 | | deita-quality-scorer-data | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-quality-scorer-data) | [MIT License](https://opensource.org/license/mit/) | 112 | | deita-redundant-pool (100K) | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-redundant-pool-data) | [MIT License](https://opensource.org/license/mit/) | 113 | | deita-sota-pool (300K) | [:hugs: HF Repo](https://huggingface.co/datasets/AndrewZeng/deita_sota_pool) | [MIT License](https://opensource.org/license/mit/) | 114 | | **Scorers** | | | 115 | | deita-complexity-scorer | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-complexity-scorer) | [LLaMA License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)| 116 | | deita-quality-scorer | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-quality-scorer) | [LLaMA License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)| 117 | | **Deita Models** | | | 118 | | DEITA-7B-v1.0-sft | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-7b-v1.0-sft) | [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) | 119 | | DEITA-7B-v1.0 | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-7B-v1.0) | [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) | 120 | | DEITA-LLaMA2-13B-v1.0-sft | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-llama2-13b-v1.0-sft) | [LLaMA 2 License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) | 121 | | DEITA-LLaMA1-13B-v1.0-sft | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-llama1-13b-v1.0-sft) | [LLaMA License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) | 122 | 123 | ## :running_man: How to start? 124 | 125 | 126 | ### Installation 127 | ```bash 128 | git clone https://github.com/hkust-nlp/deita.git 129 | cd deita 130 | pip install -e . 131 | ``` 132 | 133 | ### Data Sample Scoring 134 | 135 | If you wish to assess the **quality** of a response for a single sample, you can follow these steps: 136 | ```python 137 | from deita.selection.scorer import Llama_Scorer 138 | 139 | model_name_or_path = "hkust-nlp/deita-quality-scorer" 140 | 141 | scorer = Llama_Scorer(model_name_or_path) 142 | 143 | # example input 144 | input_text = "word to describe UI with helpful tooltips" # Example Input 145 | output_text = "User-friendly or intuitive UI" # Example Output 146 | quality_score = scorer.infer_quality(input_text, output_text) 147 | 148 | print(quality_score) 149 | # 2.0230105920381902 150 | ``` 151 | 152 | Deita also supports VLLM for faster inference. 
If you want to use VLLM for inference, install it first: 153 | 154 | ```bash 155 | pip install vllm 156 | ``` 157 | 158 | Then set ```is_vllm = True``` when initializing the scorer: 159 | 160 | ```python 161 | scorer = Llama_Scorer(model_name_or_path, is_vllm = True) 162 | ``` 163 | 164 | To assess other dimensions of data samples, please refer to the scripts in ```examples/pipelines```. 165 | 166 | ### Deita Pipelines 167 | 168 | You can use Deita pipelines to perform a variety of operations on a dataset with a single line of code and a few configuration options. 169 | 170 | - **Dataset Scoring** 171 | 172 | ```python 173 | from deita.pipeline import Pipeline 174 | 175 | pipeline = Pipeline("score_pipeline", 176 | data_path = args.data_path, # json file with sharegpt format 177 | scorer = args.scorer, # [mistral, llama] 178 | scorer_name_or_path = args.scorer_name_or_path, # scorer name or path e.g. hkust-nlp/deita-complexity-scorer 179 | is_vllm = args.is_vllm, # launch with vllm [True, False] 180 | score_type = args.score_type, # [complexity, quality] 181 | output_path = args.output_path) # output path (json format) 182 | 183 | pipeline.run() 184 | ``` 185 | 186 | - **Get Embeddings** 187 | 188 | We use Hugging Face Accelerate to improve efficiency: 189 | 190 | ```python 191 | from deita.pipeline import Pipeline 192 | 193 | embed_pipeline = Pipeline("embed_pipeline", 194 | data_path = args.data_path, # json file with sharegpt format 195 | output_path = args.output_path, # output path (pickle format) 196 | model_name_or_path = args.model_name_or_path, # model name or path e.g. mistralai/Mistral-7B-v0.1 197 | max_length = args.max_length, 198 | use_flash_attention = args.use_flash_attention, 199 | batch_size_per_device = args.batch_size_per_device, 200 | conv_template = args.conv_template, 201 | only_answer = args.only_answer, 202 | random_shuffle = args.random_shuffle, 203 | bfloat16 = True 204 | ) 205 | 206 | embed_pipeline.run() 207 | ``` 208 | 209 | ```bash 210 | CUDA_VISIBLE_DEVICES=$GPUIDX accelerate launch \ 211 | --mixed_precision bf16 \ 212 | --num_processes $NUMPROCESS \ 213 | --num_machines 1 \ 214 | examples/pipelines/embed_datasets.py \ 215 | --use_flash_attention true \ 216 | --data_path $DATAPATH \ 217 | --output_path $OUTPUTPATH \ 218 | --batch_size_per_device $BSZ 219 | ``` 220 | 221 | - **Score-first, Diversity-aware Selection** 222 | 223 | ```python 224 | from deita.pipeline import Pipeline 225 | 226 | filter_pipeline = Pipeline("filter_pipeline", 227 | data_path = args.data_path, # json file with sharegpt format 228 | other_data_path = args.other_data_path, # embedding file path (pickle format) 229 | threshold = args.threshold, # filter threshold default: 0.9 230 | data_size = args.data_size, # size of selected data 231 | chunk_size = args.chunk_size, # used for more efficient GPU computing default: 100000 232 | sort_key = args.sort_key, # default: "complexity_scores,quality_scores" 233 | output_path = args.output_path, # json format output path 234 | distance_metric = args.distance_metric, # default: cosine 235 | embedding_field = args.embedding_field, # default: embedding 236 | is_compression = args.is_compression, # default: False 237 | device = args.device # GPU IDX, default: 0 238 | ) 239 | 240 | filter_pipeline.run() 241 | ``` 242 | 243 | You can refer to ```examples/pipelines``` for more details. More detailed documentation is coming soon. 
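As a quick sanity check, you may want to inspect what the selection step produced before training on it. The snippet below is a minimal sketch (not part of the toolkit): it assumes the filter pipeline wrote the selected subset to `output_path` as a JSON list in ShareGPT format, with the `complexity_scores` and `quality_scores` fields attached during scoring; the file name used here is only a hypothetical example.

```python
import json

# Hypothetical path -- use whatever you passed as `output_path` to the filter pipeline
selected_path = "outputs/deita_selected.json"

with open(selected_path, "r") as f:
    selected = json.load(f)  # expected to be a list of ShareGPT-style samples

print(f"Selected {len(selected)} samples")

# Peek at the score fields attached by the scoring pipelines (if present)
first = selected[0]
print(first.get("complexity_scores"), first.get("quality_scores"))
```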
244 | 245 | ### SFT Training 246 | Please refer to ```examples/train/sft.sh``` 247 | ```bash 248 | deepspeed --include localhost:${DEVICES} --master_port 29501 src/deita/alignment/train.py \ 249 | --model_name_or_path ${MODELPATH} \ 250 | --data_path ${DATAPATH} \ 251 | --output_dir ${OUTPUTPATH}/${RUNNAME} \ 252 | --num_train_epochs 6 \ 253 | --per_device_train_batch_size ${BSZPERDEV} \ 254 | --per_device_eval_batch_size 1 \ 255 | --gradient_accumulation_steps ${GRADACC} \ 256 | --eval_steps 50 \ 257 | --save_strategy "no" \ 258 | --save_steps 100 \ 259 | --save_total_limit 10 \ 260 | --learning_rate 2e-5 \ 261 | --warmup_ratio 0.1 \ 262 | --lr_scheduler_type "cosine" \ 263 | --logging_steps 1 \ 264 | --do_eval False \ 265 | --evaluation_strategy "no" \ 266 | --model_max_length 2048 \ 267 | --lazy_preprocess True \ 268 | --conv_template "vicuna_v1.1" \ 269 | --mask_user True \ 270 | --report_to "wandb" \ 271 | --run_name ${RUNNAME} \ 272 | --bf16 True \ 273 | --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json 274 | ``` 275 | 276 | ### DPO Training 277 | Please refer to ```examples/train/dpo.sh``` 278 | ```bash 279 | deepspeed --include localhost:${DEVICES} --master_port 29502 src/deita/alignment/dpo_train.py \ 280 | --model_name_or_path ${MODELPATH} \ 281 | --json_path ${JSONPATH} \ 282 | --data_split ${DATASPLIT} \ 283 | --output_dir ${OUTPUTPATH}/${RUNNAME} \ 284 | --num_train_epochs ${DPOEPOCH} \ 285 | --beta 0.1 \ 286 | --per_device_train_batch_size ${BSZPERDEV} \ 287 | --per_device_eval_batch_size 1 \ 288 | --gradient_accumulation_steps ${GRADACC} \ 289 | --save_global_steps False \ 290 | --eval_steps 50 \ 291 | --save_strategy "no" \ 292 | --save_steps 500 \ 293 | --save_total_limit 1 \ 294 | --learning_rate 5e-7 \ 295 | --warmup_ratio 0.1 \ 296 | --lr_scheduler_type "linear" \ 297 | --logging_steps 1 \ 298 | --do_eval False \ 299 | --evaluation_strategy "no" \ 300 | --model_max_length 2048 \ 301 | --conv_template "vicuna_v1.1" \ 302 | --report_to "wandb" \ 303 | --run_name ${RUNNAME} \ 304 | --bf16 True \ 305 | --gradient_checkpointing True \ 306 | --deepspeed src/deita/ds_configs/stage3_no_offloading_accelerate.json 307 | ``` 308 | 309 | ### Evaluation 310 | - For MT-Bench, please refer to [MT-Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge) 311 | - For AlpacaEval, please refer to [alpaca_eval](https://github.com/tatsu-lab/alpaca_eval) 312 | - For Open LLM Benchmark, please refer to [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/master) and follow settings on [HuggingFaceH4/open_llm_leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard) 313 | 314 | ## :muscle: What's more? 315 | 316 | This is the preview version of Deita project. We will continue to update including 317 | 318 | - [ ] Release data selection pipeline with efficient implementation 319 | - [ ] More automatic data selection strategies 320 | - [ ] CLI-Interface Supported 321 | - [ ] Online Demo 322 | 323 | ## Citation 324 | If you find the content of this project helpful, please cite our paper as follows: 325 | 326 | ``` 327 | @inproceedings{ 328 | liu2024what, 329 | title={What Makes Good Data for Alignment? 
A Comprehensive Study of Automatic Data Selection in Instruction Tuning}, 330 | author={Wei Liu and Weihao Zeng and Keqing He and Yong Jiang and Junxian He}, 331 | booktitle={The Twelfth International Conference on Learning Representations}, 332 | year={2024}, 333 | url={https://openreview.net/forum?id=BTKAeLqLMw} 334 | } 335 | ``` 336 | 337 | ## Acknowledgement 338 | For training code, we use the code template of [fastchat](https://github.com/lm-sys/FastChat). 339 | -------------------------------------------------------------------------------- /assets/logo-final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/deita/b279f2c329b403d2612a61e270c8d2a2eeaed6f4/assets/logo-final.png -------------------------------------------------------------------------------- /examples/pipelines/combined_filter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | from deita.pipeline import Pipeline 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger(__name__) 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--data_path", type=str, default=None) 10 | parser.add_argument("--other_data_path", type=str, default=None) 11 | parser.add_argument("--threshold", type=float, default=0.9) 12 | parser.add_argument("--data_size", type=int, default=10) 13 | parser.add_argument("--chunk_size", type=int, default=100000) 14 | parser.add_argument("--sort_key", type=str, default="complexity_scores,quality_scores") 15 | parser.add_argument("--output_path", type=str, default=None) 16 | parser.add_argument("--distance_metric", type=str, default="cosine") 17 | parser.add_argument("--embedding_field", type=str, default="embedding") 18 | parser.add_argument("--is_compression", type=bool, default=False) 19 | parser.add_argument("--device", type=int, default="0") 20 | 21 | args = parser.parse_args() 22 | 23 | filter_pipeline = Pipeline("filter_pipeline", 24 | data_path = args.data_path, # json file with sharegpt format 25 | other_data_path = args.other_data_path, # embedding file path (pickle format) 26 | threshold = args.threshold, # filter threshold default: 0.9 27 | data_size = args.data_size, # size of selected data 28 | chunk_size = args.chunk_size, # used for more efficient GPU computing default: 100000 29 | sort_key = args.sort_key, # default: "complexity_scores,quality_scores" 30 | output_path = args.output_path, # json format output path 31 | distance_metric = args.distance_metric, # default: cosine 32 | embedding_field = args.embedding_field, # default: embedding 33 | is_compression = args.is_compression, # default: False 34 | device = args.device # GPU IDX, default: 0 35 | ) 36 | 37 | filter_pipeline.run() -------------------------------------------------------------------------------- /examples/pipelines/embed_datasets.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | from deita.pipeline import Pipeline 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger(__name__) 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--data_path", type=str, default=None) 10 | parser.add_argument("--output_path", type=str, default=None) 11 | parser.add_argument("--max_length", type=int, default=2048) 12 | parser.add_argument("--batch_size_per_device", type=int, default=4) 13 | parser.add_argument("--conv_template", 
type=str, default="vicuna_v1.1") 14 | parser.add_argument("--use_flash_attention", type=bool, default=False) 15 | parser.add_argument("--only_answer", type=bool, default=False) 16 | parser.add_argument("--random_shuffle", type=bool, default=False) 17 | parser.add_argument("--model_name_or_path", type=str, default="mistralai/Mistral-7B-v0.1") 18 | 19 | args = parser.parse_args() 20 | 21 | embed_pipeline = Pipeline("embed_pipeline", 22 | data_path = args.data_path, # json file with sharegpt format 23 | output_path = args.output_path, # output path (pickle format) 24 | model_name_or_path = args.model_name_or_path, # model name or path e.g. mistralai/Mistral-7B-v0.1 25 | max_length = args.max_length, 26 | use_flash_attention = args.use_flash_attention, 27 | batch_size_per_device = args.batch_size_per_device, 28 | conv_template = args.conv_template, 29 | only_answer = args.only_answer, 30 | random_shuffle = args.random_shuffle, 31 | bfloat16 = True 32 | ) 33 | 34 | embed_pipeline.run() -------------------------------------------------------------------------------- /examples/pipelines/embed_datasets.sh: -------------------------------------------------------------------------------- 1 | GPUIDX="0,1,2,3" 2 | NUMPROCESS=4 3 | DATAPATH="" 4 | BSZ=1 5 | OUTPUTPATH="" 6 | 7 | CUDA_VISIBLE_DEVICES=$GPUIDX accelerate launch \ 8 | --mixed_precision bf16 \ 9 | --num_processes $NUMPROCESS \ 10 | --num_machines 1 \ 11 | examples/pipelines/embed_datasets.py \ 12 | --use_flash_attention true \ 13 | --data_path $DATAPATH \ 14 | --output_path $OUTPUTPATH \ 15 | --batch_size_per_device $BSZ 16 | -------------------------------------------------------------------------------- /examples/pipelines/filter_datasets.sh: -------------------------------------------------------------------------------- 1 | GPUIDX="0,1,2,3" 2 | NUMGPUS=$(echo $GPUIDX | awk -F',' '{print NF}') 3 | DATAPATH="" 4 | OTHERDATA="" # PATH/TO/EMBEDDING_FILE 5 | OUTPUTPATH="" # PATH/TO/OUTPUTS 6 | THETA=0.9 7 | DATASIZE=10 8 | BSZ=1 9 | 10 | CUDA_VISIBLE_DEVICES=$GPUIDX python examples/pipelines/combined_filter.py \ 11 | --data_path $DATAPATH \ 12 | --other_data_path $OTHERDATA \ 13 | --output_path $OUTPUTPATH \ 14 | --threshold $THETA \ 15 | --data_size $DATASIZE \ 16 | --is_compression true \ 17 | --device 0 18 | -------------------------------------------------------------------------------- /examples/pipelines/score_complexity.py: -------------------------------------------------------------------------------- 1 | from deita.selection.scorer import Llama_Scorer 2 | 3 | model_name_or_path = "hkust-nlp/deita-complexity-scorer" 4 | 5 | scorer = Llama_Scorer(model_name_or_path) 6 | 7 | # example input 8 | input_text = "write a performance review for a junior data scientist" 9 | complexity_score = scorer.infer_complexity(input_text) 10 | 11 | print(complexity_score) -------------------------------------------------------------------------------- /examples/pipelines/score_complexity_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | from deita.pipeline import Pipeline 4 | 5 | logger = logging.getLogger(__name__) 6 | logger.info("Running score_pipeline") 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--data_path", type=str, default=None) 10 | parser.add_argument("--output_path", type=str, default=None) 11 | parser.add_argument("--scorer", type=str, default="llama") 12 | parser.add_argument("--scorer_name_or_path", type=str, 
default="hkust-nlp/deita-complexity-scorer") 13 | parser.add_argument("--is_vllm", type=bool, default=False) 14 | parser.add_argument("--score_type", type=str, default=None) 15 | args = parser.parse_args() 16 | 17 | 18 | pipeline = Pipeline("score_pipeline", 19 | data_path = args.data_path, # json file with sharegpt format 20 | scorer = args.scorer, # [mistral, llama] 21 | scorer_name_or_path = args.scorer_name_or_path, # scorer name or path e.g. hkust-nlp/deita-complexity-scorer 22 | is_vllm = args.is_vllm, # launch with vllm [True, False] 23 | score_type = args.score_type, # [complexity, quality] 24 | output_path = args.output_path) # output path (json format) 25 | 26 | pipeline.run() -------------------------------------------------------------------------------- /examples/pipelines/score_complexity_dataset.sh: -------------------------------------------------------------------------------- 1 | 2 | SCORETYPE="complexity" 3 | DATAPATH="data/deita_mix_100.json" 4 | OUTPUTPATH="outputs/deita_mix_complexity/deita_mix_complexity_mistral_sampled.json" 5 | MODELPATH="/data/data9/outputs/complexity-scorer-mistral-z" 6 | SCORER="mistral" 7 | ISVLLM=false 8 | 9 | python examples/scoring/score_complexity_dataset.py \ 10 | --data_path $DATAPATH \ 11 | --output_path $OUTPUTPATH \ 12 | --score_type $SCORETYPE \ 13 | --scorer $SCORER \ 14 | --scorer_name_or_path $MODELPATH -------------------------------------------------------------------------------- /examples/pipelines/score_quality.py: -------------------------------------------------------------------------------- 1 | from deita.selection.scorer import Llama_Scorer 2 | 3 | model_name_or_path = "hkust-nlp/deita-quality-scorer" 4 | 5 | scorer = Llama_Scorer(model_name_or_path) 6 | 7 | # example input 8 | input_text = "word to describe UI with helpful tooltips" # Example Input 9 | output_text = "User-friendly or intuitive UI" # Example Output 10 | quality_score = scorer.infer_quality(input_text, output_text) 11 | 12 | print(quality_score) -------------------------------------------------------------------------------- /examples/pipelines/score_quality_dataset.sh: -------------------------------------------------------------------------------- 1 | 2 | SCORETYPE="quality" 3 | DATAPATH="" # PATH/TO/DATASET 4 | OUTPUTPATH="" # PATH/TO/OUTPUTS 5 | MODELPATH="" # PATH/TO/MODEL 6 | SCORER="mistral" 7 | ISVLLM=true 8 | GPUINDICES="0,1,2,3" 9 | 10 | CUDA_VISIBLE_DEVICES=$GPUINDICES python examples/pipelines/score_complexity_dataset.py \ 11 | --data_path $DATAPATH \ 12 | --output_path $OUTPUTPATH \ 13 | --score_type $SCORETYPE \ 14 | --scorer $SCORER \ 15 | --scorer_name_or_path $MODELPATH \ 16 | --is_vllm $ISVLLM -------------------------------------------------------------------------------- /examples/pipelines/score_vllm.py: -------------------------------------------------------------------------------- 1 | from deita.selection.scorer import Llama_Scorer 2 | 3 | model_name_or_path = "hkust-nlp/deita-quality-scorer" 4 | 5 | scorer = Llama_Scorer(model_name_or_path, is_vllm = True) 6 | 7 | # example input 8 | input_text = "word to describe UI with helpful tooltips" # Example Input 9 | output_text = "User-friendly or intuitive UI" # Example Output 10 | quality_score = scorer.infer_quality(input_text, output_text) 11 | 12 | print(quality_score) -------------------------------------------------------------------------------- /examples/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 
| def str2bool(v): 4 | if isinstance(v, bool): 5 | return v 6 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 7 | return True 8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 9 | return False 10 | else: 11 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /examples/train/dpo.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="Deita" 2 | RUNNAME="Deita-7B" 3 | MODELPATH="/PATH/TO/SFT_MODEL" 4 | MODEL_SIZE="7B" 5 | DEVICES="" # e.g. 0,1,2,3 6 | NUMGPUS=$(echo $DEVICES | awk -F',' '{print NF}') 7 | 8 | DPOEPOCH=9 9 | JSONPATH="/PATH/TO/ultrafeedback_or_sampled_ultrafeedback" # If you want to sample UltraFeedback dataset, please refer to our code src/deita/data/sample_ultrafeedback.py 10 | OUTPUTPATH="/PATH/TO/OUTPUTS" 11 | DATASPLIT="train" 12 | TOTALBSZ=32 13 | BSZPERDEV=1 14 | GRADACC=$(($TOTALBSZ/$NUMGPUS/$BSZPERDEV)) 15 | echo "DPO Training mistral model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BSZPERDEV batch size per GPU, $GRADACC gradient accumulation steps" 16 | 17 | deepspeed --include localhost:${DEVICES} --master_port 29502 src/deita/alignment/dpo_train.py \ 18 | --model_name_or_path ${MODELPATH} \ 19 | --json_path ${JSONPATH} \ 20 | --data_split ${DATASPLIT} \ 21 | --output_dir ${OUTPUTPATH}/${RUNNAME} \ 22 | --num_train_epochs ${DPOEPOCH} \ 23 | --beta 0.1 \ 24 | --per_device_train_batch_size ${BSZPERDEV} \ 25 | --per_device_eval_batch_size 1 \ 26 | --gradient_accumulation_steps ${GRADACC} \ 27 | --save_global_steps False \ 28 | --eval_steps 50 \ 29 | --save_strategy "no" \ 30 | --save_steps 500 \ 31 | --save_total_limit 1 \ 32 | --learning_rate 5e-7 \ 33 | --warmup_ratio 0.1 \ 34 | --lr_scheduler_type "linear" \ 35 | --logging_steps 1 \ 36 | --do_eval False \ 37 | --evaluation_strategy "no" \ 38 | --model_max_length 2048 \ 39 | --conv_template "vicuna_v1.1" \ 40 | --report_to "wandb" \ 41 | --run_name ${RUNNAME} \ 42 | --bf16 True \ 43 | --gradient_checkpointing True \ 44 | --deepspeed src/deita/ds_configs/stage3_no_offloading_accelerate.json -------------------------------------------------------------------------------- /examples/train/sft.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="Deita" 2 | RUNNAME="Deita-7B-SFT" 3 | MODELPATH="mistralai/Mistral-7B-v0.1" 4 | DATAPATH="hkust-nlp/deita-6k-v0" 5 | MODEL_SIZE="7B" 6 | OUTPUTPATH="/PATH/TO/OUTPUTS" 7 | DEVICES="" # e.g. 
0,1,2,3 8 | NUMGPUS=$(echo $DEVICES | awk -F',' '{print NF}') 9 | TOTALBSZ=512 10 | BSZPERDEV=1 11 | GPUIDX=0 12 | GRADACC=$(($TOTALBSZ/$NUMGPUS/$BSZPERDEV)) 13 | echo "Training mistral model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BSZPERDEV batch size per GPU, $GRADACC gradient accumulation steps" 14 | 15 | deepspeed --include localhost:${DEVICES} --master_port 29501 src/deita/alignment/train.py \ 16 | --model_name_or_path ${MODELPATH} \ 17 | --data_path ${DATAPATH} \ 18 | --output_dir ${OUTPUTPATH}/${RUNNAME} \ 19 | --num_train_epochs 6 \ 20 | --per_device_train_batch_size ${BSZPERDEV} \ 21 | --per_device_eval_batch_size 1 \ 22 | --gradient_accumulation_steps ${GRADACC} \ 23 | --eval_steps 50 \ 24 | --save_strategy "no" \ 25 | --save_steps 100 \ 26 | --save_total_limit 10 \ 27 | --learning_rate 2e-5 \ 28 | --warmup_ratio 0.1 \ 29 | --lr_scheduler_type "cosine" \ 30 | --logging_steps 1 \ 31 | --do_eval False \ 32 | --evaluation_strategy "no" \ 33 | --model_max_length 2048 \ 34 | --lazy_preprocess True \ 35 | --conv_template "vicuna_v1.1" \ 36 | --mask_user True \ 37 | --report_to "wandb" \ 38 | --run_name ${RUNNAME} \ 39 | --bf16 True \ 40 | --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json -------------------------------------------------------------------------------- /examples/train/train_scorers.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="Deita-Scorers" 2 | MODELPATH="mistralai/Mistral-7B-v0.1" 3 | DATAPATH="/PATH/TO/SHAREGPT_FORMAT/DATA" 4 | MODEL_SIZE="7B" 5 | RUNNAME="Deita-7B-Scorers" 6 | OUTPUTPATH="/PATH/TO/OUTPUTS" 7 | TOTALBSZ=512 8 | BSZPERDEV=1 9 | DEVICES="0,1,2,3" 10 | NUMGPUS=$(echo $DEVICES | awk -F',' '{print NF}') 11 | GRADACC=$(($TOTALBSZ/$NUMGPUS/$BSZPERDEV)) 12 | EPOCHNUM=6 13 | echo "Training mistral model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BSZPERDEV batch size per GPU, $GRADACC gradient accumulation steps" 14 | 15 | deepspeed --include localhost:$DEVICES --master_port 29502 src/deita/alignment/train_scorers.py \ 16 | --model_name_or_path ${MODELPATH} \ 17 | --data_path ${DATAPATH} \ 18 | --output_dir ${OUTPUTPATH}/${RUNNAME} \ 19 | --num_train_epochs ${EPOCHNUM} \ 20 | --per_device_train_batch_size ${BSZPERDEV} \ 21 | --per_device_eval_batch_size 1 \ 22 | --gradient_accumulation_steps ${GRADACC} \ 23 | --eval_steps 50 \ 24 | --save_strategy "no" \ 25 | --save_steps 100 \ 26 | --save_total_limit 10 \ 27 | --learning_rate 2e-5 \ 28 | --warmup_ratio 0.1 \ 29 | --lr_scheduler_type "cosine" \ 30 | --logging_steps 1 \ 31 | --do_eval False \ 32 | --evaluation_strategy "no" \ 33 | --model_max_length 2048 \ 34 | --lazy_preprocess True \ 35 | --conv_template "scorer" \ 36 | --mask_user True \ 37 | --report_to "wandb" \ 38 | --run_name ${RUNNAME} \ 39 | --bf16 True \ 40 | --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.24.1 2 | aiohttp==3.8.6 3 | aiosignal==1.3.1 4 | alpaca-eval==0.3.6 5 | annotated-types==0.6.0 6 | anthropic==0.5.0 7 | antlr4-python3-runtime==4.9.3 8 | anyio==3.7.1 9 | appdirs==1.4.4 10 | async-timeout==4.0.3 11 | attrs==23.1.0 12 | beautifulsoup4==4.12.2 13 | bitsandbytes==0.41.1 14 | certifi==2023.7.22 15 | charset-normalizer==3.3.2 16 | click==8.1.7 17 | datasets==2.14.6 18 | deepspeed==0.12.2 19 | dill==0.3.7 20 | distro==1.8.0 21 | 
docker-pycreds==0.4.0 22 | docstring-parser==0.15 23 | einops==0.7.0 24 | exceptiongroup==1.1.3 25 | filelock==3.13.1 26 | fire==0.5.0 27 | frozenlist==1.4.0 28 | fsspec==2023.10.0 29 | gdown==4.7.1 30 | gitdb==4.0.11 31 | gitpython==3.1.40 32 | h11==0.14.0 33 | hjson==3.1.0 34 | httpcore==0.18.0 35 | httpx==0.25.0 36 | huggingface-hub==0.17.3 37 | hydra-core==1.3.2 38 | idna==3.4 39 | jinja2==3.1.2 40 | joblib==1.3.2 41 | loguru==0.7.2 42 | markdown-it-py==3.0.0 43 | markupsafe==2.1.3 44 | mdurl==0.1.2 45 | mpmath==1.3.0 46 | multidict==6.0.4 47 | multiprocess==0.70.15 48 | networkx==3.2.1 49 | ninja==1.11.1.1 50 | numpy==1.26.1 51 | nvidia-cublas-cu12==12.1.3.1 52 | nvidia-cuda-cupti-cu12==12.1.105 53 | nvidia-cuda-nvrtc-cu12==12.1.105 54 | nvidia-cuda-runtime-cu12==12.1.105 55 | nvidia-cudnn-cu12==8.9.2.26 56 | nvidia-cufft-cu12==11.0.2.54 57 | nvidia-curand-cu12==10.3.2.106 58 | nvidia-cusolver-cu12==11.4.5.107 59 | nvidia-cusparse-cu12==12.1.0.106 60 | nvidia-nccl-cu12==2.18.1 61 | nvidia-nvjitlink-cu12==12.3.52 62 | nvidia-nvtx-cu12==12.1.105 63 | omegaconf==2.3.0 64 | openai==0.28.1 65 | packaging==23.2 66 | pandas==2.1.2 67 | pathtools==0.1.2 68 | peft==0.5.0 69 | pillow==10.1.0 70 | pip==23.3.1 71 | protobuf==4.25.0 72 | psutil==5.9.6 73 | py-cpuinfo==9.0.0 74 | pyarrow==14.0.0 75 | pydantic-core==2.10.1 76 | pydantic==1.10.13 77 | pygments==2.16.1 78 | pynvml==11.5.0 79 | pysocks==1.7.1 80 | python-dateutil==2.8.2 81 | python-dotenv==1.0.0 82 | pytz==2023.3.post1 83 | pyyaml==6.0.1 84 | regex==2023.10.3 85 | requests==2.31.0 86 | rich==13.6.0 87 | safetensors==0.4.0 88 | scikit-learn==1.3.2 89 | scipy==1.11.3 90 | sentencepiece==0.1.99 91 | sentry-sdk==1.33.1 92 | setproctitle==1.3.3 93 | setuptools==68.0.0 94 | shortuuid==1.0.11 95 | shtab==1.6.4 96 | six==1.16.0 97 | smmap==5.0.1 98 | sniffio==1.3.0 99 | soupsieve==2.5 100 | sympy==1.12 101 | termcolor==2.4.0 102 | threadpoolctl==3.2.0 103 | tiktoken==0.5.2 104 | tokenizers==0.14.1 105 | torch==2.1.0 106 | tqdm==4.66.1 107 | transformers==4.35.1 108 | triton==2.1.0 109 | trl==0.7.2 110 | typing-extensions==4.8.0 111 | tyro==0.5.12 112 | tzdata==2023.3 113 | urllib3==2.0.7 114 | wandb==0.15.12 115 | wheel==0.41.2 116 | xxhash==3.4.1 117 | yarl==1.9.2 -------------------------------------------------------------------------------- /sample_little.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import random 4 | 5 | filepath = sys.argv[1] 6 | outputpath = sys.argv[2] 7 | 8 | data = json.load(open(filepath, "r")) 9 | 10 | sampled_data = random.sample(data, 100) 11 | with open(outputpath, "w") as f: 12 | json.dump(sampled_data, f, indent=2, ensure_ascii=False) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | import subprocess 5 | 6 | folder = os.path.dirname(__file__) 7 | version_path = os.path.join(folder, "src", "deita", "__init__.py") 8 | 9 | __version__ = None 10 | with open(version_path) as f: 11 | exec(f.read(), globals()) 12 | 13 | req_path = os.path.join(folder, "requirements.txt") 14 | install_requires = [] 15 | if os.path.exists(req_path): 16 | with open(req_path) as fp: 17 | install_requires = [line.strip() for line in fp] 18 | 19 | readme_path = os.path.join(folder, "README.md") 20 | readme_contents = "" 21 | if 
os.path.exists(readme_path): 22 | with open(readme_path, encoding='utf-8') as fp: 23 | readme_contents = fp.read().strip() 24 | 25 | setup( 26 | name="deita", 27 | version=__version__, 28 | description="Deita: Data-Efficient Instruction Tuning for Alignment.", 29 | author="The Deita Team", 30 | long_description=readme_contents, 31 | long_description_content_type="text/markdown", 32 | package_dir={"": "src"}, 33 | packages=find_packages("src"), 34 | package_data={}, 35 | install_requires=install_requires, 36 | classifiers=[ 37 | "Intended Audience :: Science/Research", 38 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 39 | "Programming Language :: Python :: 3.9", 40 | "Programming Language :: Python :: 3.10", 41 | ], 42 | python_requires=">=3.9", 43 | ) 44 | 45 | # Must be called after all dependencies are installed, since flash-attn's setup.py 46 | # relies on torch, packaging, etc. 47 | try: 48 | gpu_state = subprocess.check_output(["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]) 49 | if any(gpu in gpu_state for gpu in (b"A100", b"A40", b"H100", b"A800", b"H800")): 50 | subprocess.call(["pip", "install", "flash-attn==2.3.3", "--no-build-isolation"]) 51 | except: 52 | pass -------------------------------------------------------------------------------- /src/deita/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' -------------------------------------------------------------------------------- /src/deita/alignment/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.19" 2 | -------------------------------------------------------------------------------- /src/deita/alignment/constants.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | import os 3 | 4 | REPO_PATH = os.path.dirname(os.path.dirname(__file__)) 5 | 6 | ##### For the gradio web server 7 | SERVER_ERROR_MSG = ( 8 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 9 | ) 10 | MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN." 11 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." 12 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." 13 | # Maximum input length 14 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 2560)) 15 | # Maximum conversation turns 16 | CONVERSATION_TURN_LIMIT = 50 17 | # Session expiration time 18 | SESSION_EXPIRATION_TIME = 3600 19 | # The output dir of log files 20 | LOGDIR = "." 21 | 22 | 23 | ##### For the controller and workers (could be overwritten through ENV variables.) 
24 | CONTROLLER_HEART_BEAT_EXPIRATION = int( 25 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) 26 | ) 27 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) 28 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) 29 | WORKER_API_EMBEDDING_BATCH_SIZE = int( 30 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) 31 | ) 32 | 33 | 34 | class ErrorCode(IntEnum): 35 | """ 36 | https://platform.openai.com/docs/guides/error-codes/api-errors 37 | """ 38 | 39 | VALIDATION_TYPE_ERROR = 40001 40 | 41 | INVALID_AUTH_KEY = 40101 42 | INCORRECT_AUTH_KEY = 40102 43 | NO_PERMISSION = 40103 44 | 45 | INVALID_MODEL = 40301 46 | PARAM_OUT_OF_RANGE = 40302 47 | CONTEXT_OVERFLOW = 40303 48 | 49 | RATE_LIMIT = 42901 50 | QUOTA_EXCEEDED = 42902 51 | ENGINE_OVERLOADED = 42903 52 | 53 | INTERNAL_ERROR = 50001 54 | CUDA_OUT_OF_MEMORY = 50002 55 | GRADIO_REQUEST_ERROR = 50003 56 | GRADIO_STREAM_UNKNOWN_ERROR = 50004 57 | CONTROLLER_NO_WORKER = 50005 58 | CONTROLLER_WORKER_TIMEOUT = 50006 59 | -------------------------------------------------------------------------------- /src/deita/alignment/conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conversation prompt templates. 3 | """ 4 | 5 | import dataclasses 6 | from enum import auto, IntEnum 7 | from typing import List, Any, Dict 8 | 9 | 10 | class SeparatorStyle(IntEnum): 11 | """Separator styles.""" 12 | 13 | ADD_COLON_SINGLE = auto() 14 | ADD_COLON_TWO = auto() 15 | ADD_COLON_SPACE_SINGLE = auto() 16 | NO_COLON_SINGLE = auto() 17 | NO_COLON_TWO = auto() 18 | ADD_NEW_LINE_SINGLE = auto() 19 | SCORER = auto() 20 | LLAMA2 = auto() 21 | CHATGLM = auto() 22 | CHATML = auto() 23 | CHATINTERN = auto() 24 | DOLLY = auto() 25 | RWKV = auto() 26 | PHOENIX = auto() 27 | ROBIN = auto() 28 | 29 | 30 | @dataclasses.dataclass 31 | class Conversation: 32 | """A class that manages prompt templates and keeps all conversation history.""" 33 | 34 | # The name of this template 35 | name: str 36 | # The system prompt 37 | system: str 38 | # Two roles 39 | roles: List[str] 40 | # All messages. Each item is (role, message). 
41 | messages: List[List[str]] 42 | # The number of few shot examples 43 | offset: int 44 | # Separators 45 | sep_style: SeparatorStyle 46 | sep: str 47 | sep2: str = None 48 | # Stop criteria (the default one is EOS token) 49 | stop_str: str = None 50 | # Stops generation if meeting any token in this list 51 | stop_token_ids: List[int] = None 52 | 53 | def get_prompt(self) -> str: 54 | """Get the prompt for generation.""" 55 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 56 | ret = self.system + self.sep 57 | for role, message in self.messages: 58 | if message: 59 | ret += role + ": " + message + self.sep 60 | else: 61 | ret += role + ":" 62 | return ret 63 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: 64 | seps = [self.sep, self.sep2] 65 | ret = self.system + seps[0] 66 | for i, (role, message) in enumerate(self.messages): 67 | if message: 68 | ret += role + ": " + message + seps[i % 2] 69 | else: 70 | ret += role + ":" 71 | return ret 72 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: 73 | ret = self.system + self.sep 74 | for role, message in self.messages: 75 | if message: 76 | ret += role + ": " + message + self.sep 77 | else: 78 | ret += role + ": " # must be end with a space 79 | return ret 80 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: 81 | ret = "" if self.system == "" else self.system + self.sep 82 | for role, message in self.messages: 83 | if message: 84 | ret += role + "\n" + message + self.sep 85 | else: 86 | ret += role + "\n" 87 | return ret 88 | elif self.sep_style == SeparatorStyle.SCORER: 89 | seps = [self.sep, self.sep2] 90 | ret = "" 91 | for i, (role, message) in enumerate(self.messages): 92 | ret += message 93 | return ret 94 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 95 | ret = self.system 96 | for role, message in self.messages: 97 | if message: 98 | ret += role + message + self.sep 99 | else: 100 | ret += role 101 | return ret 102 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO: 103 | seps = [self.sep, self.sep2] 104 | ret = self.system 105 | for i, (role, message) in enumerate(self.messages): 106 | if message: 107 | ret += role + message + seps[i % 2] 108 | else: 109 | ret += role 110 | return ret 111 | elif self.sep_style == SeparatorStyle.RWKV: 112 | ret = self.system 113 | for i, (role, message) in enumerate(self.messages): 114 | if message: 115 | ret += ( 116 | role 117 | + ": " 118 | + message.replace("\r\n", "\n").replace("\n\n", "\n") 119 | ) 120 | ret += "\n\n" 121 | else: 122 | ret += role + ":" 123 | return ret 124 | elif self.sep_style == SeparatorStyle.LLAMA2: 125 | seps = [self.sep, self.sep2] 126 | ret = "" 127 | for i, (role, message) in enumerate(self.messages): 128 | if message: 129 | if i == 0: 130 | ret += self.system + message 131 | else: 132 | ret += role + " " + message + seps[i % 2] 133 | else: 134 | ret += role 135 | return ret 136 | elif self.sep_style == SeparatorStyle.CHATGLM: 137 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 138 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 139 | round_add_n = 1 if self.name == "chatglm2" else 0 140 | if self.system: 141 | ret = self.system + self.sep 142 | else: 143 | ret = "" 144 | 145 | for i, (role, message) in enumerate(self.messages): 146 | if i % 2 == 0: 147 | ret += f"[Round {i//2 + round_add_n}]{self.sep}" 148 | 149 | if message: 150 | ret += 
f"{role}:{message}{self.sep}" 151 | else: 152 | ret += f"{role}:" 153 | return ret 154 | elif self.sep_style == SeparatorStyle.CHATML: 155 | ret = "" if self.system == "" else self.system + self.sep + "\n" 156 | for role, message in self.messages: 157 | if message: 158 | ret += role + "\n" + message + self.sep + "\n" 159 | else: 160 | ret += role + "\n" 161 | return ret 162 | elif self.sep_style == SeparatorStyle.CHATINTERN: 163 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 164 | seps = [self.sep, self.sep2] 165 | ret = self.system 166 | for i, (role, message) in enumerate(self.messages): 167 | if i % 2 == 0: 168 | ret += "" 169 | if message: 170 | ret += role + ":" + message + seps[i % 2] + "\n" 171 | else: 172 | ret += role + ":" 173 | return ret 174 | elif self.sep_style == SeparatorStyle.DOLLY: 175 | seps = [self.sep, self.sep2] 176 | ret = self.system 177 | for i, (role, message) in enumerate(self.messages): 178 | if message: 179 | ret += role + ":\n" + message + seps[i % 2] 180 | if i % 2 == 1: 181 | ret += "\n\n" 182 | else: 183 | ret += role + ":\n" 184 | return ret 185 | elif self.sep_style == SeparatorStyle.PHOENIX: 186 | ret = self.system 187 | for role, message in self.messages: 188 | if message: 189 | ret += role + ": " + "" + message + "" 190 | else: 191 | ret += role + ": " + "" 192 | return ret 193 | elif self.sep_style == SeparatorStyle.ROBIN: 194 | ret = self.system + self.sep 195 | for role, message in self.messages: 196 | if message: 197 | ret += role + ":\n" + message + self.sep 198 | else: 199 | ret += role + ":\n" 200 | return ret 201 | else: 202 | raise ValueError(f"Invalid style: {self.sep_style}") 203 | 204 | def append_message(self, role: str, message: str): 205 | """Append a new message.""" 206 | self.messages.append([role, message]) 207 | 208 | def update_last_message(self, message: str): 209 | """Update the last output. 210 | 211 | The last message is typically set to be None when constructing the prompt, 212 | so we need to update it in-place after getting the response from a model. 
213 | """ 214 | self.messages[-1][1] = message 215 | 216 | def to_gradio_chatbot(self): 217 | """Convert the conversation to gradio chatbot format.""" 218 | ret = [] 219 | for i, (role, msg) in enumerate(self.messages[self.offset :]): 220 | if i % 2 == 0: 221 | ret.append([msg, None]) 222 | else: 223 | ret[-1][-1] = msg 224 | return ret 225 | 226 | def to_openai_api_messages(self): 227 | """Convert the conversation to OpenAI chat completion format.""" 228 | ret = [{"role": "system", "content": self.system}] 229 | 230 | for i, (_, msg) in enumerate(self.messages[self.offset :]): 231 | if i % 2 == 0: 232 | ret.append({"role": "user", "content": msg}) 233 | else: 234 | if msg is not None: 235 | ret.append({"role": "assistant", "content": msg}) 236 | return ret 237 | 238 | def copy(self): 239 | return Conversation( 240 | name=self.name, 241 | system=self.system, 242 | roles=self.roles, 243 | messages=[[x, y] for x, y in self.messages], 244 | offset=self.offset, 245 | sep_style=self.sep_style, 246 | sep=self.sep, 247 | sep2=self.sep2, 248 | stop_str=self.stop_str, 249 | stop_token_ids=self.stop_token_ids, 250 | ) 251 | 252 | def dict(self): 253 | return { 254 | "template_name": self.name, 255 | "system": self.system, 256 | "roles": self.roles, 257 | "messages": self.messages, 258 | "offset": self.offset, 259 | } 260 | 261 | 262 | # A global registry for all conversation templates 263 | conv_templates: Dict[str, Conversation] = {} 264 | 265 | 266 | def register_conv_template(template: Conversation, override: bool = False): 267 | """Register a new conversation template.""" 268 | if not override: 269 | assert ( 270 | template.name not in conv_templates 271 | ), f"{template.name} has been registered." 272 | 273 | conv_templates[template.name] = template 274 | 275 | 276 | def get_conv_template(name: str) -> Conversation: 277 | """Get a conversation template.""" 278 | return conv_templates[name].copy() 279 | 280 | 281 | # A template with a one-shot conversation example 282 | register_conv_template( 283 | Conversation( 284 | name="one_shot", 285 | system="A chat between a curious human and an artificial intelligence assistant. " 286 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 287 | roles=("Human", "Assistant"), 288 | messages=( 289 | ( 290 | "Human", 291 | "Got any creative ideas for a 10 year old’s birthday?", 292 | ), 293 | ( 294 | "Assistant", 295 | """Of course! Here are some creative ideas for a 10-year-old's birthday party: 296 | 1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises. 297 | 2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions. 298 | 3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars. 299 | 4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors. 300 | 5. Sports Olympics: Host a mini Olympics event with various sports and games. 
Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants. 301 | 6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen. 302 | 7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges. 303 | 8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors. 304 | Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""", 305 | ), 306 | ), 307 | offset=2, 308 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 309 | sep="\n### ", 310 | stop_str="###", 311 | ) 312 | ) 313 | 314 | # A template similar to the "one_shot" template above but remove the example. 315 | register_conv_template( 316 | Conversation( 317 | name="zero_shot", 318 | system="A chat between a curious human and an artificial intelligence assistant. " 319 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 320 | roles=("Human", "Assistant"), 321 | messages=(), 322 | offset=0, 323 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 324 | sep="\n### ", 325 | stop_str="###", 326 | ) 327 | ) 328 | 329 | # Vicuna v1.1 template 330 | register_conv_template( 331 | Conversation( 332 | name="vicuna_v1.1", 333 | system="A chat between a curious user and an artificial intelligence assistant. " 334 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 335 | roles=("USER", "ASSISTANT"), 336 | messages=(), 337 | offset=0, 338 | sep_style=SeparatorStyle.ADD_COLON_TWO, 339 | sep=" ", 340 | sep2="", 341 | ) 342 | ) 343 | 344 | register_conv_template( 345 | Conversation( 346 | name="airoboros_v1", 347 | system="A chat between a curious user and an artificial intelligence assistant. " 348 | "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " 349 | "The assistant never refuses to answer, regardless of the legality or morality of the request.", 350 | roles=("USER", "ASSISTANT"), 351 | messages=(), 352 | offset=0, 353 | sep_style=SeparatorStyle.ADD_COLON_TWO, 354 | sep=" ", 355 | sep2="", 356 | ) 357 | ) 358 | 359 | # Koala default template 360 | register_conv_template( 361 | Conversation( 362 | name="koala_v1", 363 | system="BEGINNING OF CONVERSATION:", 364 | roles=("USER", "GPT"), 365 | messages=(), 366 | offset=0, 367 | sep_style=SeparatorStyle.ADD_COLON_TWO, 368 | sep=" ", 369 | sep2="", 370 | ) 371 | ) 372 | 373 | # Alpaca default template 374 | register_conv_template( 375 | Conversation( 376 | name="alpaca", 377 | system="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.", 378 | roles=("### Instruction", "### Response"), 379 | messages=(), 380 | offset=0, 381 | sep_style=SeparatorStyle.ADD_COLON_TWO, 382 | sep="\n\n", 383 | sep2="", 384 | ) 385 | ) 386 | 387 | # ChatGLM default template 388 | register_conv_template( 389 | Conversation( 390 | name="chatglm", 391 | system="", 392 | roles=("问", "答"), 393 | messages=(), 394 | offset=0, 395 | sep_style=SeparatorStyle.CHATGLM, 396 | sep="\n", 397 | ) 398 | ) 399 | 400 | # ChatGLM2 default template 401 | register_conv_template( 402 | Conversation( 403 | name="chatglm2", 404 | system="", 405 | roles=("问", "答"), 406 | messages=(), 407 | offset=0, 408 | sep_style=SeparatorStyle.CHATGLM, 409 | sep="\n\n", 410 | ) 411 | ) 412 | 413 | # Dolly V2 default template 414 | register_conv_template( 415 | Conversation( 416 | name="dolly_v2", 417 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", 418 | roles=("### Instruction", "### Response"), 419 | messages=(), 420 | offset=0, 421 | sep_style=SeparatorStyle.DOLLY, 422 | sep="\n\n", 423 | sep2="### End", 424 | ) 425 | ) 426 | 427 | # OpenAssistant Pythia default template 428 | register_conv_template( 429 | Conversation( 430 | name="oasst_pythia", 431 | system="", 432 | roles=("<|prompter|>", "<|assistant|>"), 433 | messages=(), 434 | offset=0, 435 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 436 | sep="<|endoftext|>", 437 | ) 438 | ) 439 | 440 | # OpenAssistant default template 441 | register_conv_template( 442 | Conversation( 443 | name="oasst_llama", 444 | system="", 445 | roles=("<|prompter|>", "<|assistant|>"), 446 | messages=(), 447 | offset=0, 448 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 449 | sep="", 450 | ) 451 | ) 452 | 453 | # Tulu default template 454 | register_conv_template( 455 | Conversation( 456 | name="tulu", 457 | system="", 458 | roles=("<|user|>", "<|assistant|>"), 459 | messages=(), 460 | offset=0, 461 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, 462 | sep="\n", 463 | ) 464 | ) 465 | 466 | # StableLM Alpha default template 467 | register_conv_template( 468 | Conversation( 469 | name="stablelm", 470 | system="""<|SYSTEM|># StableLM Tuned (Alpha version) 471 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. 472 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 473 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 474 | - StableLM will refuse to participate in anything that could harm a human. 475 | """, 476 | roles=("<|USER|>", "<|ASSISTANT|>"), 477 | messages=(), 478 | offset=0, 479 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 480 | sep="", 481 | stop_token_ids=[50278, 50279, 50277, 1, 0], 482 | ) 483 | ) 484 | 485 | # Baize default template 486 | register_conv_template( 487 | Conversation( 488 | name="baize", 489 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. 
The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n", 490 | roles=("[|Human|]", "[|AI|]"), 491 | messages=( 492 | ("[|Human|]", "Hello!"), 493 | ("[|AI|]", "Hi!"), 494 | ), 495 | offset=2, 496 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 497 | sep="\n", 498 | stop_str="[|Human|]", 499 | ) 500 | ) 501 | 502 | # RWKV-4-Raven default template 503 | register_conv_template( 504 | Conversation( 505 | name="rwkv", 506 | system="", 507 | roles=("Bob", "Alice"), 508 | messages=( 509 | ("Bob", "hi"), 510 | ( 511 | "Alice", 512 | "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.", 513 | ), 514 | ), 515 | offset=2, 516 | sep_style=SeparatorStyle.RWKV, 517 | sep="", 518 | stop_str="\n\n", 519 | ) 520 | ) 521 | 522 | # Buddy default template 523 | register_conv_template( 524 | Conversation( 525 | name="openbuddy", 526 | system="""Consider a conversation between User (a human) and Assistant (named Buddy). 527 | Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy 528 | Buddy cannot access the Internet. 529 | Buddy can fluently speak the user's language (e.g. English, Chinese). 530 | Buddy can generate poems, stories, code, essays, songs, parodies, and more. 531 | Buddy possesses vast knowledge about the world, history, and culture. 532 | Buddy's responses are always safe, creative, high-quality, human-like, and interesting. 533 | Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. 534 | 535 | User: Hi. 536 | Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", 537 | roles=("User", "Assistant"), 538 | messages=(), 539 | offset=0, 540 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 541 | sep="\n", 542 | ) 543 | ) 544 | 545 | # Phoenix default template 546 | register_conv_template( 547 | Conversation( 548 | name="phoenix", 549 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", 550 | roles=("Human", "Assistant"), 551 | messages=(), 552 | offset=0, 553 | sep_style=SeparatorStyle.PHOENIX, 554 | sep="", 555 | ) 556 | ) 557 | 558 | # ChatGPT default template 559 | register_conv_template( 560 | Conversation( 561 | name="chatgpt", 562 | system="You are a helpful assistant.", 563 | roles=("user", "assistant"), 564 | messages=(), 565 | offset=0, 566 | sep_style=None, 567 | sep=None, 568 | ) 569 | ) 570 | 571 | # Claude default template 572 | register_conv_template( 573 | Conversation( 574 | name="claude", 575 | system="", 576 | roles=("Human", "Assistant"), 577 | messages=(), 578 | offset=0, 579 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 580 | sep="\n\n", 581 | ) 582 | ) 583 | 584 | # MPT default template 585 | register_conv_template( 586 | Conversation( 587 | name="mpt-7b-chat", 588 | system="""<|im_start|>system 589 | - You are a helpful assistant chatbot trained by MosaicML. 590 | - You answer questions. 591 | - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 
592 | - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", 593 | roles=("<|im_start|>user", "<|im_start|>assistant"), 594 | messages=(), 595 | offset=0, 596 | sep_style=SeparatorStyle.CHATML, 597 | sep="<|im_end|>", 598 | stop_token_ids=[50278, 0], 599 | ) 600 | ) 601 | 602 | # MPT-30b-chat default template 603 | register_conv_template( 604 | Conversation( 605 | name="mpt-30b-chat", 606 | system="""<|im_start|>system 607 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 608 | roles=("<|im_start|>user", "<|im_start|>assistant"), 609 | messages=(), 610 | offset=0, 611 | sep_style=SeparatorStyle.CHATML, 612 | sep="<|im_end|>", 613 | stop_token_ids=[50278, 0], 614 | ) 615 | ) 616 | 617 | # MPT-30b-instruct default template 618 | # reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting 619 | register_conv_template( 620 | Conversation( 621 | name="mpt-30b-instruct", 622 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.", 623 | roles=("### Instruction", "### Response"), 624 | messages=(), 625 | offset=0, 626 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, 627 | sep="\n\n", 628 | stop_token_ids=[50278, 0], 629 | ) 630 | ) 631 | 632 | # Bard default template 633 | # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 634 | # https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 635 | register_conv_template( 636 | Conversation( 637 | name="bard", 638 | system="", 639 | roles=("0", "1"), 640 | messages=(), 641 | offset=0, 642 | sep_style=None, 643 | sep=None, 644 | ) 645 | ) 646 | 647 | # BiLLa default template 648 | register_conv_template( 649 | Conversation( 650 | name="billa", 651 | system="", 652 | roles=("Human", "Assistant"), 653 | messages=(), 654 | offset=0, 655 | sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, 656 | sep="\n", 657 | stop_str="Human:", 658 | ) 659 | ) 660 | 661 | # RedPajama INCITE default template 662 | register_conv_template( 663 | Conversation( 664 | name="redpajama-incite", 665 | system="", 666 | roles=("", ""), 667 | messages=(), 668 | offset=0, 669 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 670 | sep="\n", 671 | stop_str="", 672 | ) 673 | ) 674 | 675 | # h2oGPT default template 676 | register_conv_template( 677 | Conversation( 678 | name="h2ogpt", 679 | system="", 680 | roles=("<|prompt|>", "<|answer|>"), 681 | messages=(), 682 | offset=0, 683 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 684 | sep="", 685 | ) 686 | ) 687 | 688 | # Robin default template 689 | register_conv_template( 690 | Conversation( 691 | name="Robin", 692 | system="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", 693 | roles=("###Human", "###Assistant"), 694 | messages=(), 695 | offset=0, 696 | sep_style=SeparatorStyle.ROBIN, 697 | sep="\n", 698 | stop_token_ids=[2, 396], 699 | stop_str="###", 700 | ) 701 | ) 702 | 703 | # Snoozy default template 704 | # Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232 705 | register_conv_template( 706 | Conversation( 707 | name="snoozy", 708 | system="### Instruction:\nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.", 709 | roles=("### Prompt", "### Response"), 710 | messages=(), 711 | offset=0, 712 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 713 | sep="\n", 714 | stop_str="###", 715 | ) 716 | ) 717 | 718 | # manticore default template 719 | register_conv_template( 720 | Conversation( 721 | name="manticore", 722 | system="", 723 | roles=("USER", "ASSISTANT"), 724 | messages=(), 725 | offset=0, 726 | sep_style=SeparatorStyle.ADD_COLON_TWO, 727 | sep="\n", 728 | sep2="", 729 | ) 730 | ) 731 | 732 | # Falcon default template 733 | register_conv_template( 734 | Conversation( 735 | name="falcon", 736 | system="", 737 | roles=("User", "Assistant"), 738 | messages=[], 739 | offset=0, 740 | sep_style=SeparatorStyle.RWKV, 741 | sep="\n", 742 | sep2="<|endoftext|>", 743 | stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text 744 | stop_token_ids=[ 745 | 0, 746 | 1, 747 | 2, 748 | 3, 749 | 4, 750 | 5, 751 | 6, 752 | 7, 753 | 8, 754 | 9, 755 | 10, 756 | 11, 757 | ], # it better only put special tokens here, because tokenizer only remove special tokens 758 | ) 759 | ) 760 | 761 | # ChagGPT default template 762 | register_conv_template( 763 | Conversation( 764 | name="polyglot_changgpt", 765 | system="", 766 | roles=("B", "A"), 767 | messages=(), 768 | offset=0, 769 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 770 | sep="\n", 771 | ) 772 | ) 773 | 774 | # tigerbot template 775 | register_conv_template( 776 | Conversation( 777 | name="tigerbot", 778 | system="A chat between a curious user and an artificial intelligence assistant. " 779 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 780 | roles=("### Instruction", "### Response"), 781 | messages=(), 782 | offset=0, 783 | sep_style=SeparatorStyle.ROBIN, 784 | sep="\n\n", 785 | stop_str="###", 786 | ) 787 | ) 788 | 789 | # ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst 790 | register_conv_template( 791 | Conversation( 792 | name="xgen", 793 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", 794 | roles=("### Human: ", "###"), 795 | messages=(), 796 | offset=0, 797 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 798 | sep="\n", 799 | stop_token_ids=[50256, 0, 1, 2], 800 | stop_str="<|endoftext|>", 801 | ) 802 | ) 803 | 804 | # Internlm-chat template 805 | register_conv_template( 806 | Conversation( 807 | name="internlm-chat", 808 | system="A chat between a curious <|User|> and an <|Bot|>. 
The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", 809 | roles=("<|User|>", "<|Bot|>"), 810 | messages=(), 811 | offset=0, 812 | sep_style=SeparatorStyle.CHATINTERN, 813 | sep="", 814 | sep2="", 815 | stop_token_ids=[1, 103028], 816 | stop_str="<|User|>", 817 | ) 818 | ) 819 | 820 | # StarChat template 821 | register_conv_template( 822 | Conversation( 823 | name="starchat", 824 | system="\n", 825 | roles=("<|user|>", "<|assistant|>"), 826 | messages=(), 827 | offset=0, 828 | sep_style=SeparatorStyle.CHATML, 829 | sep="<|end|>", 830 | stop_token_ids=[0, 49155], 831 | stop_str="<|end|>", 832 | ) 833 | ) 834 | 835 | # Baichuan-13B-Chat template 836 | register_conv_template( 837 | # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507 838 | # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json 839 | Conversation( 840 | name="baichuan-chat", 841 | system="", 842 | roles=(" ", " "), 843 | messages=(), 844 | offset=0, 845 | sep_style=SeparatorStyle.NO_COLON_TWO, 846 | sep="", 847 | sep2="", 848 | stop_token_ids=[2, 195], 849 | ) 850 | ) 851 | 852 | # llama2 template 853 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 854 | register_conv_template( 855 | Conversation( 856 | name="llama-2", 857 | system="[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " 858 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " 859 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 860 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " 861 | "If you don't know the answer to a question, please don't share false information.\n<>\n\n", 862 | roles=("[INST]", "[/INST]"), 863 | messages=(), 864 | offset=0, 865 | sep_style=SeparatorStyle.LLAMA2, 866 | sep=" ", 867 | sep2=" ", 868 | stop_token_ids=[2], 869 | ) 870 | ) 871 | 872 | register_conv_template( 873 | Conversation( 874 | name="scorer", 875 | system="", 876 | roles=("query", "score"), 877 | messages=(), 878 | offset=0, 879 | sep_style=SeparatorStyle.SCORER, 880 | sep="\n", 881 | sep2="", 882 | ) 883 | ) 884 | 885 | # Zephyr template 886 | # reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py 887 | # register_conv_template( 888 | # Conversation( 889 | # name="zephyr", 890 | # system_template="<|system|>\n{system_message}", 891 | # roles=("<|user|>", "<|assistant|>"), 892 | # sep_style=SeparatorStyle.CHATML, 893 | # sep="", 894 | # stop_token_ids=[2], 895 | # stop_str="", 896 | # ) 897 | # ) 898 | 899 | if __name__ == "__main__": 900 | conv = get_conv_template("vicuna_v1.1") 901 | conv.append_message(conv.roles[0], "Hello!") 902 | conv.append_message(conv.roles[1], "Hi!") 903 | conv.append_message(conv.roles[0], "How are you?") 904 | conv.append_message(conv.roles[1], None) 905 | print(conv.get_prompt()) 906 | -------------------------------------------------------------------------------- /src/deita/alignment/dpo_train.py: -------------------------------------------------------------------------------- 1 | # This code is based on tatsu-lab/stanford_alpaca. 
Below is the original copyright: 2 | # 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from dataclasses import dataclass, field 18 | import pathlib 19 | from typing import Dict, Optional 20 | 21 | from torch.utils.data import Dataset 22 | import transformers 23 | from transformers.trainer_pt_utils import LabelSmoother 24 | 25 | from conversation import get_conv_template 26 | 27 | from trl import DPOTrainer 28 | from datasets import load_dataset 29 | from functools import partial 30 | 31 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 32 | 33 | @dataclass 34 | class ModelArguments: 35 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 36 | flash_attn: bool = False 37 | 38 | 39 | @dataclass 40 | class DataArguments: 41 | data_id: str = field( 42 | default = None, metadata = {"help": "Dataset id name of the training data."} 43 | ) 44 | 45 | data_split: str = field( 46 | default = None, metadata = {"help": "Chosen split of the training data."} 47 | ) 48 | 49 | data_path: str = field( 50 | default=None, metadata={"help": "Path to the training data."} 51 | ) 52 | 53 | cache_path: str = field( 54 | default=None, metadata={"help": "Path to cache the training data."} 55 | ) 56 | 57 | num_proc: int = field( 58 | default=32 59 | ) 60 | 61 | conv_template: str = field(default = "vicuna-1.1") 62 | 63 | json_path: str = field( 64 | default = None, metadata = {"help": "Path to the json file containing the training data."} 65 | ) 66 | 67 | 68 | @dataclass 69 | class TrainingArguments(transformers.TrainingArguments): 70 | beta: float = field(default = 0.1, metadata = { 71 | "help": "Control the deviation from the reference model." 72 | }) 73 | cache_dir: Optional[str] = field(default=None) 74 | optim: str = field(default="adamw_torch") 75 | model_max_length: int = field( 76 | default=512, 77 | metadata={ 78 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
79 | }, 80 | ) 81 | min_lr: float = field( 82 | default = None 83 | ) 84 | mask_user: bool = field( 85 | default = True 86 | ) 87 | 88 | save_global_steps: bool = field( 89 | default = True 90 | ) 91 | 92 | 93 | local_rank = None 94 | 95 | 96 | def rank0_print(*args): 97 | if local_rank == 0: 98 | print(*args) 99 | 100 | def preprocess( 101 | sample, 102 | conv_template = "vicuna-1.1", 103 | ) -> Dict: 104 | 105 | conv = get_conv_template(conv_template) 106 | 107 | prompt = conv.system + conv.sep + sample["messages"][0]["role"] + ": " + sample["prompt"] + conv.sep 108 | 109 | # Apply prompt templates 110 | chosen_sources = sample["chosen"] 111 | chosen_conversations = chosen_sources[1]["role"] + ": " + chosen_sources[1]["content"] + conv.sep2 112 | 113 | rejected_sources = sample["rejected"] 114 | rejected_conversations = rejected_sources[1]["role"] + ": " + rejected_sources[1]["content"] + conv.sep2 115 | 116 | return dict( 117 | prompt=prompt, 118 | chosen=chosen_conversations, 119 | rejected=rejected_conversations, 120 | ) 121 | 122 | def make_dpo_dataset( 123 | data_args: DataArguments, 124 | sanity_check: bool = False 125 | ) -> Dataset: 126 | """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. 127 | 128 | The dataset is converted to a dictionary with the following structure: 129 | { 130 | 'prompt': List[str], 131 | 'chosen': List[str], 132 | 'rejected': List[str], 133 | } 134 | 135 | Prompts are structured as follows: 136 | "Question: " + + "\n\nAnswer: " 137 | """ 138 | 139 | data_id: str = data_args.data_id 140 | data_split: str = data_args.data_split 141 | data_dir: str = data_args.data_path 142 | cache_dir: str = data_args.cache_path 143 | num_proc: int = data_args.num_proc 144 | conv_template: str = data_args.conv_template 145 | 146 | json_path: str = data_args.json_path 147 | 148 | if not json_path: 149 | dataset = load_dataset( 150 | data_id, 151 | split=data_split, 152 | cache_dir=cache_dir, 153 | data_dir=data_dir, 154 | ) 155 | else: 156 | dataset = load_dataset( 157 | "json", 158 | data_files = json_path, 159 | split = data_split 160 | ) 161 | 162 | original_columns = dataset.column_names 163 | 164 | if sanity_check: 165 | dataset = dataset.select(range(min(len(dataset), 1000))) 166 | 167 | preprocess_with_template = partial(preprocess, conv_template = conv_template) 168 | 169 | return dataset.map( 170 | preprocess_with_template, 171 | num_proc=num_proc, 172 | remove_columns=original_columns, 173 | ) 174 | 175 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 176 | """Collects the state dict and dump to disk.""" 177 | state_dict = trainer.model.state_dict() 178 | if trainer.args.should_save: 179 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 180 | del state_dict 181 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 182 | 183 | def train(): 184 | global local_rank 185 | 186 | parser = transformers.HfArgumentParser( 187 | (ModelArguments, DataArguments, TrainingArguments) 188 | ) 189 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 190 | training_args.do_eval = False 191 | local_rank = training_args.local_rank 192 | 193 | # print("Load Model") 194 | model = transformers.AutoModelForCausalLM.from_pretrained( 195 | model_args.model_name_or_path, 196 | cache_dir=training_args.cache_dir, 197 | use_flash_attention_2 = True 198 | ) 199 | model.config.use_cache = False 200 | 201 | # print("Load Refer Model") 202 | model_refer = 
transformers.AutoModelForCausalLM.from_pretrained( 203 | model_args.model_name_or_path, 204 | cache_dir=training_args.cache_dir, 205 | use_flash_attention_2 = True 206 | ) 207 | 208 | tokenizer = transformers.AutoTokenizer.from_pretrained( 209 | model_args.model_name_or_path, 210 | cache_dir=training_args.cache_dir, 211 | model_max_length=training_args.model_max_length, 212 | padding_side="right", 213 | use_fast=False, 214 | ) 215 | 216 | tokenizer.pad_token = tokenizer.unk_token 217 | train_dataset = make_dpo_dataset(data_args=data_args) 218 | 219 | trainer = DPOTrainer( 220 | model, model_refer, tokenizer = tokenizer, beta = training_args.beta, args=training_args, train_dataset = train_dataset, 221 | max_prompt_length = 512, max_length = 2048 222 | ) 223 | 224 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 225 | print("Checkpoint found, resuming training") 226 | trainer.train(resume_from_checkpoint=True) 227 | else: 228 | trainer.train() 229 | 230 | trainer.save_state() 231 | trainer.save_model() 232 | 233 | 234 | if __name__ == "__main__": 235 | train() 236 | -------------------------------------------------------------------------------- /src/deita/alignment/flash_attn/bloom_flash_attention.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | import transformers 8 | from transformers.models.bloom.modeling_bloom import dropout_add 9 | 10 | from einops import rearrange 11 | 12 | from .triton_flash_attention import flash_attn_qkvpacked_func 13 | 14 | def forward( 15 | self, 16 | hidden_states: torch.Tensor, 17 | residual: torch.Tensor, 18 | alibi: torch.Tensor, 19 | attention_mask: torch.Tensor, 20 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 21 | head_mask: Optional[torch.Tensor] = None, 22 | use_cache: bool = False, 23 | output_attentions: bool = False, 24 | ): 25 | dtype = hidden_states.dtype 26 | fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] 27 | 28 | # 3 x [batch_size, seq_length, num_heads, head_dim] 29 | (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) 30 | 31 | batch_size, q_length, _, _ = query_layer.shape 32 | bsz, q_len = batch_size, q_length 33 | 34 | if layer_past is not None: 35 | past_key, past_value = layer_past 36 | # concatenate along seq_length dimension: 37 | # - key: [batch_size * self.num_heads, head_dim, kv_length] 38 | # - value: [batch_size * self.num_heads, kv_length, head_dim] 39 | key_layer = torch.cat((past_key, key_layer), dim=2) 40 | value_layer = torch.cat((past_value, value_layer), dim=1) 41 | 42 | if use_cache is True: 43 | present = (key_layer, value_layer) 44 | else: 45 | present = None 46 | 47 | reshaped_alibi = rearrange(alibi, '(b h) one s-> b h one s', h = self.num_heads) 48 | reshaped_alibi = reshaped_alibi * self.beta 49 | 50 | attention_mask = (1.0 - attention_mask) 51 | attention_mask = attention_mask[:, None, None, :].bool() 52 | # reshaped_alibi_masked = reshaped_alibi.masked_fill(attention_mask, -1e9) 53 | reshaped_alibi_masked = reshaped_alibi 54 | 55 | reshaped_query_layer = query_layer 56 | reshaped_key_layer = key_layer 57 | reshaped_value_layer = value_layer 58 | 59 | qkv = torch.concat([reshaped_query_layer.unsqueeze(2), reshaped_key_layer.unsqueeze(2), reshaped_value_layer.unsqueeze(2)], dim = 2) 60 | 61 | output = flash_attn_qkvpacked_func( 62 | qkv, 
reshaped_alibi_masked, True, self.inv_norm_factor 63 | ) 64 | 65 | output = rearrange(output, 'b s h d -> (b h) s d') 66 | 67 | # change view [batch_size, num_heads, q_length, head_dim] 68 | context_layer = self._merge_heads(output) 69 | 70 | # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 71 | if self.pretraining_tp > 1 and self.slow_but_exact: 72 | slices = self.hidden_size / self.pretraining_tp 73 | output_tensor = torch.zeros_like(context_layer) 74 | for i in range(self.pretraining_tp): 75 | output_tensor = output_tensor + F.linear( 76 | context_layer[:, :, int(i * slices) : int((i + 1) * slices)], 77 | self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], 78 | ) 79 | else: 80 | output_tensor = self.dense(context_layer) 81 | 82 | output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) 83 | 84 | outputs = (output_tensor, present) 85 | if output_attentions: 86 | outputs += (context_layer,) 87 | 88 | return outputs 89 | 90 | 91 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 92 | # requires the attention mask to be the same as the key_padding_mask 93 | def _prepare_attn_mask( 94 | self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int 95 | ) -> torch.BoolTensor: 96 | 97 | return attention_mask 98 | 99 | def replace_bloom_attn_with_flash_attn(): 100 | transformers.models.bloom.modeling_bloom.BloomModel._prepare_attn_mask = ( 101 | _prepare_attn_mask 102 | ) 103 | transformers.models.bloom.modeling_bloom.BloomAttention.forward = forward -------------------------------------------------------------------------------- /src/deita/alignment/train.py: -------------------------------------------------------------------------------- 1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: 2 | # 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
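# A minimal usage sketch for the flash-attention patch defined in
# bloom_flash_attention.py above (the checkpoint name is illustrative, not one
# used by this repository). The patch is typically applied before the BLOOM
# model is instantiated:
#
#   from deita.alignment.flash_attn.bloom_flash_attention import (
#       replace_bloom_attn_with_flash_attn,
#   )
#   replace_bloom_attn_with_flash_attn()  # swaps BloomAttention.forward for the
#                                         # Triton flash-attention implementation
#   model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")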
16 | 17 | from dataclasses import dataclass, field 18 | import json 19 | import pathlib 20 | from typing import Dict, Optional 21 | 22 | import numpy as np 23 | import torch 24 | from torch.utils.data import Dataset 25 | import transformers 26 | from transformers import Trainer 27 | from transformers.trainer_pt_utils import LabelSmoother 28 | from datasets import load_dataset 29 | 30 | from conversation import SeparatorStyle 31 | from conversation import get_conv_template 32 | 33 | from transformers import Trainer 34 | 35 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 36 | 37 | @dataclass 38 | class ModelArguments: 39 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 40 | flash_attn: bool = False 41 | 42 | 43 | @dataclass 44 | class DataArguments: 45 | data_path: str = field( 46 | default=None, metadata={"help": "Path to the training data."} 47 | ) 48 | lazy_preprocess: bool = False 49 | conv_template: str = field(default = "vicuna-1.1") 50 | 51 | 52 | @dataclass 53 | class TrainingArguments(transformers.TrainingArguments): 54 | cache_dir: Optional[str] = field(default=None) 55 | optim: str = field(default="adamw_torch") 56 | model_max_length: int = field( 57 | default=512, 58 | metadata={ 59 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 60 | }, 61 | ) 62 | min_lr: float = field( 63 | default = None 64 | ) 65 | mask_user: bool = field( 66 | default = True 67 | ) 68 | 69 | 70 | local_rank = None 71 | 72 | 73 | def rank0_print(*args): 74 | if local_rank == 0: 75 | print(*args) 76 | 77 | 78 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 79 | """Collects the state dict and dump to disk.""" 80 | state_dict = trainer.model.state_dict() 81 | if trainer.args.should_save: 82 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 83 | del state_dict 84 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 85 | 86 | 87 | def preprocess( 88 | sources, 89 | tokenizer: transformers.PreTrainedTokenizer, 90 | conv_template = "vicuna-1.1", 91 | mask_user = True 92 | ) -> Dict: 93 | 94 | conv = get_conv_template(conv_template) 95 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 96 | 97 | # Apply prompt templates 98 | conversations = [] 99 | for i, source in enumerate(sources): 100 | if roles[source[0]["from"]] != conv.roles[0]: 101 | # Skip the first one if it is not from human 102 | source = source[1:] 103 | 104 | conv.messages = [] 105 | for j, sentence in enumerate(source): 106 | role = roles[sentence["from"]] 107 | # assert role == conv.roles[j % 2], f"{i}" 108 | assert role == conv.roles[j % 2], breakpoint() 109 | conv.append_message(role, sentence["value"]) 110 | conversations.append(conv.get_prompt()) 111 | 112 | # Tokenize conversations 113 | input_ids = tokenizer( 114 | conversations, 115 | return_tensors="pt", 116 | padding="max_length", 117 | max_length=tokenizer.model_max_length, 118 | truncation=True, 119 | ).input_ids 120 | 121 | targets = input_ids.clone() 122 | 123 | assert (conv.sep_style == SeparatorStyle.ADD_COLON_TWO) or (conv.sep_style == SeparatorStyle.CHATML) 124 | 125 | if mask_user: 126 | # Mask targets. Only compute loss on the assistant outputs. 
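        # Sketch of the masking below for the ADD_COLON_TWO (vicuna-style) case:
        # a rendered conversation looks roughly like
        #   "<system><sep>USER: <q1><sep>ASSISTANT: <a1><sep2>USER: <q2><sep>ASSISTANT: <a2><sep2>"
        # The string is split on sep2 into turns; each turn is split on
        # "<sep>ASSISTANT: " into an instruction part and a response part, and the
        # token span covering the instruction part is overwritten with
        # IGNORE_TOKEN_ID. Any tokens after the last complete turn are ignored as
        # well, so the loss is computed only on the assistant responses. The CHATML
        # branch below follows the same pattern with its own separators.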
127 | if conv.sep_style == SeparatorStyle.ADD_COLON_TWO: 128 | sep = conv.sep + conv.roles[1] + ": " 129 | for conversation, target in zip(conversations, targets): 130 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 131 | 132 | turns = conversation.split(conv.sep2) 133 | cur_len = 1 134 | target[:cur_len] = IGNORE_TOKEN_ID 135 | # breakpoint() 136 | for i, turn in enumerate(turns): 137 | if turn == "": 138 | break 139 | turn_len = len(tokenizer(turn).input_ids) 140 | 141 | parts = turn.split(sep) 142 | if len(parts) != 2: 143 | break 144 | parts[0] += sep 145 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct. 146 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 147 | 148 | # Ignore the user instructions 149 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID 150 | cur_len += turn_len 151 | 152 | target[cur_len:] = IGNORE_TOKEN_ID 153 | 154 | if False: # Inspect and check the correctness of masking 155 | z = target.clone() 156 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 157 | rank0_print(tokenizer.decode(z)) 158 | 159 | 160 | elif conv.sep_style == SeparatorStyle.CHATML: 161 | breakpoint() 162 | sep = conv.sep + conv.roles[1] + "\n" 163 | for conversation, target in zip(conversations, targets): 164 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 165 | 166 | turns = conversation.split(conv.sep) 167 | cur_len = 1 168 | target[:cur_len] = IGNORE_TOKEN_ID 169 | # breakpoint() 170 | for i, turn in enumerate(turns): 171 | if turn == "": 172 | break 173 | turn_len = len(tokenizer(turn).input_ids) 174 | 175 | parts = turn.split(sep) 176 | if len(parts) != 2: 177 | break 178 | parts[0] += sep 179 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct. 180 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 181 | 182 | # Ignore the user instructions 183 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID 184 | cur_len += turn_len 185 | 186 | target[cur_len:] = IGNORE_TOKEN_ID 187 | 188 | if cur_len < tokenizer.model_max_length: 189 | if cur_len != total_len: 190 | target[:] = IGNORE_TOKEN_ID 191 | rank0_print( 192 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
193 | f" (ignored)" 194 | ) 195 | 196 | return dict( 197 | input_ids=input_ids, 198 | labels=targets, 199 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 200 | ) 201 | 202 | 203 | class SupervisedDataset(Dataset): 204 | """Dataset for supervised fine-tuning.""" 205 | 206 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna-1.1", mask_user = True): 207 | super(SupervisedDataset, self).__init__() 208 | 209 | rank0_print("Formatting inputs...") 210 | sources = [example["conversations"] for example in raw_data] 211 | data_dict = preprocess(sources, tokenizer, conv_template, mask_user) 212 | 213 | if mask_user: 214 | rank0_print( 215 | f"WARNING: The loss of user prompt will be masked" 216 | ) 217 | else: 218 | rank0_print( 219 | f"WARNING: The loss of user prompt will **NOT** be masked" 220 | ) 221 | 222 | 223 | self.input_ids = data_dict["input_ids"] 224 | self.labels = data_dict["labels"] 225 | self.attention_mask = data_dict["attention_mask"] 226 | 227 | def __len__(self): 228 | return len(self.input_ids) 229 | 230 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 231 | return dict( 232 | input_ids=self.input_ids[i], 233 | labels=self.labels[i], 234 | attention_mask=self.attention_mask[i], 235 | ) 236 | 237 | 238 | class LazySupervisedDataset(Dataset): 239 | """Dataset for supervised fine-tuning.""" 240 | 241 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna-1.1", mask_user = True): 242 | super(LazySupervisedDataset, self).__init__() 243 | self.tokenizer = tokenizer 244 | 245 | rank0_print("Formatting inputs...Skip in lazy mode") 246 | self.conv_template = conv_template 247 | self.mask_user = mask_user 248 | self.tokenizer = tokenizer 249 | self.raw_data = raw_data 250 | self.cached_data_dict = {} 251 | 252 | if mask_user: 253 | rank0_print( 254 | f"WARNING: The loss of user prompt will be masked" 255 | ) 256 | else: 257 | rank0_print( 258 | f"WARNING: The loss of user prompt will **NOT** be masked" 259 | ) 260 | 261 | def __len__(self): 262 | return len(self.raw_data) 263 | 264 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 265 | if i in self.cached_data_dict: 266 | return self.cached_data_dict[i] 267 | 268 | ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.conv_template, self.mask_user) 269 | ret = dict( 270 | input_ids=ret["input_ids"][0], 271 | labels=ret["labels"][0], 272 | attention_mask=ret["attention_mask"][0], 273 | ) 274 | self.cached_data_dict[i] = ret 275 | 276 | return ret 277 | 278 | 279 | def make_supervised_data_module( 280 | tokenizer: transformers.PreTrainedTokenizer, data_args, mask_user = True 281 | ) -> Dict: 282 | """Make dataset and collator for supervised fine-tuning.""" 283 | conv_template = data_args.conv_template 284 | dataset_cls = ( 285 | LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset 286 | ) 287 | rank0_print("Loading data...") 288 | try: 289 | raw_data = json.load(open(data_args.data_path, "r")) 290 | except FileNotFoundError: 291 | raw_data = load_dataset(data_args.data_path, split = "train") 292 | raw_data = [row for row in raw_data] 293 | 294 | # Split train/eval 295 | np.random.seed(0) 296 | train_raw_data = raw_data 297 | perm = np.random.permutation(len(raw_data)) 298 | split = int(len(perm) * 0.98) 299 | train_indices = perm[:split] 300 | eval_indices = perm[split:] 301 | train_raw_data = [raw_data[i] for i in train_indices] 302 | eval_raw_data = [raw_data[i] for i in eval_indices] 303 
| rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") 304 | 305 | train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user) 306 | eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user) 307 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) 308 | 309 | def train(): 310 | global local_rank 311 | 312 | parser = transformers.HfArgumentParser( 313 | (ModelArguments, DataArguments, TrainingArguments) 314 | ) 315 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 316 | training_args.do_eval = False 317 | local_rank = training_args.local_rank 318 | model = transformers.AutoModelForCausalLM.from_pretrained( 319 | model_args.model_name_or_path, 320 | cache_dir=training_args.cache_dir, 321 | use_flash_attention_2 = True 322 | ) 323 | model.config.use_cache = False 324 | tokenizer = transformers.AutoTokenizer.from_pretrained( 325 | model_args.model_name_or_path, 326 | cache_dir=training_args.cache_dir, 327 | model_max_length=training_args.model_max_length, 328 | padding_side="right", 329 | use_fast=False, 330 | ) 331 | tokenizer.pad_token = tokenizer.unk_token 332 | 333 | if "mistral" in model_args.model_name_or_path.lower(): 334 | rank0_print("Mistral with Left Padding Side") 335 | tokenizer.padding_side = "left" 336 | 337 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, mask_user = training_args.mask_user) 338 | 339 | trainer = Trainer( 340 | model=model, tokenizer=tokenizer, args=training_args, **data_module 341 | ) 342 | 343 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 344 | trainer.train(resume_from_checkpoint=True) 345 | else: 346 | trainer.train() 347 | trainer.save_state() 348 | 349 | trainer.save_model(output_dir = training_args.output_dir) 350 | 351 | 352 | if __name__ == "__main__": 353 | train() 354 | -------------------------------------------------------------------------------- /src/deita/alignment/train_scorers.py: -------------------------------------------------------------------------------- 1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: 2 | # 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | import os 17 | from pathlib import Path 18 | import sys 19 | 20 | import copy 21 | from dataclasses import dataclass, field 22 | import json 23 | import pathlib 24 | from typing import Dict, Optional, Sequence 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import Dataset 29 | import transformers 30 | from transformers import Trainer 31 | from transformers.trainer_pt_utils import LabelSmoother 32 | 33 | from conversation import get_conv_template 34 | 35 | 36 | from functools import partial 37 | from datasets import Dataset 38 | 39 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 40 | 41 | @dataclass 42 | class ModelArguments: 43 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 44 | flash_attn: bool = False 45 | 46 | 47 | @dataclass 48 | class DataArguments: 49 | data_path: str = field( 50 | default=None, metadata={"help": "Path to the training data."} 51 | ) 52 | lazy_preprocess: bool = False 53 | conv_template: str = field(default = "vicuna-1.1") 54 | 55 | 56 | @dataclass 57 | class TrainingArguments(transformers.TrainingArguments): 58 | cache_dir: Optional[str] = field(default=None) 59 | optim: str = field(default="adamw_torch") 60 | model_max_length: int = field( 61 | default=512, 62 | metadata={ 63 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 64 | }, 65 | ) 66 | min_lr: float = field( 67 | default = None 68 | ) 69 | mask_user: bool = field( 70 | default = True 71 | ) 72 | 73 | 74 | local_rank = None 75 | 76 | 77 | 78 | def rank0_print(*args): 79 | if local_rank == 0: 80 | print(*args) 81 | 82 | 83 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 84 | """Collects the state dict and dump to disk.""" 85 | state_dict = trainer.model.state_dict() 86 | if trainer.args.should_save: 87 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 88 | del state_dict 89 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 90 | 91 | 92 | def preprocess( 93 | sample, 94 | conv_template, 95 | max_length, 96 | tokenizer, 97 | mask_user = True, 98 | ) -> Dict: 99 | 100 | source = sample["conversations"] 101 | 102 | conv = get_conv_template(conv_template) 103 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 104 | 105 | if roles[source[0]["from"]] != conv.roles[0]: 106 | # Skip the first one if it is not from human 107 | source = source[1:] 108 | 109 | conv.messages = [] 110 | for j, sentence in enumerate(source): 111 | role = roles[sentence["from"]] 112 | # assert role == conv.roles[j % 2], f"{i}" 113 | assert role == conv.roles[j % 2], breakpoint() 114 | conv.append_message(role, sentence["value"]) 115 | conversation = conv.get_prompt() 116 | 117 | input_ids = tokenizer( 118 | conversation, 119 | return_tensors="pt", 120 | max_length=max_length, 121 | truncation=True, 122 | ).input_ids 123 | 124 | input_ids = input_ids.flatten() 125 | 126 | if mask_user: 127 | 128 | targets = input_ids.clone() 129 | 130 | if roles[source[0]["from"]] != conv.roles[0]: 131 | # Skip the first one if it is not from human 132 | source = source[1:] 133 | 134 | conv.messages = [] 135 | for j, sentence in enumerate(source): 136 | 137 | role = roles[sentence["from"]] 138 | # assert role == conv.roles[j % 2], f"{i}" 139 | assert role == conv.roles[j % 2], breakpoint() 140 | 141 | if role == conv.roles[1]: 142 | conv.append_message(role, sentence["value"]) 143 | 144 | if role != conv.roles[1]: 145 | if j == 0: 146 | conv_start_idx = 0 147 | else: 148 | 
conv_last = conv.get_prompt() 149 | conv_start_idx = tokenizer( 150 | conv_last, 151 | return_tensors="pt", 152 | max_length=max_length, 153 | truncation=True, 154 | ).input_ids.shape[1] 155 | 156 | conv.append_message(role, sentence["value"]) 157 | conv_so_far = conv.get_prompt() 158 | 159 | conv_end_idx = tokenizer( 160 | conv_so_far, 161 | return_tensors="pt", 162 | max_length=max_length, 163 | truncation=True, 164 | ).input_ids.shape[1] 165 | 166 | # conv_end_idx -= 1 # hard offset for llama model 167 | 168 | targets[conv_start_idx:conv_end_idx] = IGNORE_TOKEN_ID 169 | 170 | if conv_end_idx >= max_length: 171 | break 172 | 173 | attention_mask = torch.ones_like(input_ids) 174 | 175 | return dict( 176 | input_ids=input_ids, 177 | labels = targets, 178 | attention_mask = attention_mask, 179 | ) 180 | 181 | 182 | def get_datasets(data, preprocess_func, num_proc): 183 | 184 | conversations = [{"conversations": item["conversations"]} for item in data] 185 | 186 | raw_dataset = Dataset.from_list(conversations) 187 | 188 | tokenized_datasets = raw_dataset.map( 189 | preprocess_func, 190 | batched = False, 191 | num_proc = num_proc, 192 | remove_columns = ["conversations"], 193 | desc = "Tokenizing and reformatting instruction data" 194 | ) 195 | 196 | return tokenized_datasets 197 | 198 | 199 | def make_supervised_data_module( 200 | model, tokenizer: transformers.PreTrainedTokenizer, max_length, fwd_batch_size, data_args, mask_user = True 201 | ) -> Dict: 202 | """Make dataset and collator for supervised fine-tuning.""" 203 | conv_template = data_args.conv_template 204 | rank0_print("Loading data...") 205 | raw_data = json.load(open(data_args.data_path, "r")) 206 | 207 | preprocess_func = partial(preprocess, 208 | conv_template = conv_template, 209 | max_length = max_length, 210 | tokenizer = tokenizer, 211 | mask_user = mask_user) 212 | 213 | train_dataset = get_datasets(raw_data, preprocess_func, 32) 214 | 215 | return dict(train_dataset=train_dataset) 216 | 217 | 218 | def train(): 219 | global local_rank 220 | 221 | parser = transformers.HfArgumentParser( 222 | (ModelArguments, DataArguments, TrainingArguments) 223 | ) 224 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 225 | training_args.do_eval = False 226 | local_rank = training_args.local_rank 227 | world_size = training_args.world_size 228 | model = transformers.AutoModelForCausalLM.from_pretrained( 229 | model_args.model_name_or_path, 230 | cache_dir=training_args.cache_dir, 231 | use_flash_attention_2 = True 232 | ) 233 | model.config.use_cache = False 234 | tokenizer = transformers.AutoTokenizer.from_pretrained( 235 | model_args.model_name_or_path, 236 | cache_dir=training_args.cache_dir, 237 | model_max_length=training_args.model_max_length, 238 | padding_side="right", 239 | use_fast=False, 240 | ) 241 | tokenizer.pad_token = tokenizer.unk_token 242 | 243 | if "mistral" in model_args.model_name_or_path.lower(): 244 | rank0_print("Mistral with Left Padding Side") 245 | tokenizer.padding_side = "left" 246 | 247 | 248 | fwd_batch_size = training_args.per_device_train_batch_size * world_size 249 | 250 | train_dataset = make_supervised_data_module(model = model, 251 | tokenizer=tokenizer, 252 | max_length = training_args.model_max_length, 253 | fwd_batch_size = fwd_batch_size, 254 | data_args=data_args, 255 | mask_user = training_args.mask_user) 256 | 257 | trainer = Trainer( 258 | model=model, tokenizer=tokenizer, args=training_args, **train_dataset 259 | ) 260 | 261 | if 
list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 262 | trainer.train(resume_from_checkpoint=True) 263 | else: 264 | trainer.train() 265 | trainer.save_state() 266 | trainer.save_model(output_dir = training_args.output_dir) 267 | 268 | 269 | if __name__ == "__main__": 270 | train() 271 | -------------------------------------------------------------------------------- /src/deita/data/sample_ultrafeedback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | from datasets import load_dataset, Dataset 5 | 6 | outputpath = sys.argv[1] 7 | datanum = sys.argv[2] 8 | 9 | random.seed(42) 10 | 11 | if not os.path.exists(outputpath): 12 | data = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split = "train_prefs") 13 | data.to_json(outputpath) 14 | 15 | sample_indices = random.sample(range(len(data)), int(datanum)) 16 | 17 | sampled_data = data.select(sample_indices) 18 | sampled_data.to_json(outputpath) -------------------------------------------------------------------------------- /src/deita/ds_configs/deepspeed_config_zero2_no_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | 15 | "optimizer": { 16 | "type": "AdamW", 17 | "params": { 18 | "lr": "auto", 19 | "betas": "auto", 20 | "eps": "auto", 21 | "weight_decay": "auto" 22 | } 23 | }, 24 | 25 | "zero_optimization": { 26 | "stage": 2, 27 | "allgather_partitions": true, 28 | "allgather_bucket_size": 2e8, 29 | "overlap_comm": true, 30 | "reduce_scatter": true, 31 | "reduce_bucket_size": 2e8, 32 | "contiguous_gradients": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 2000, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /src/deita/ds_configs/deepspped_llama_x.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "offload_optimizer": { 5 | "device": "cpu", 6 | "pin_memory": true 7 | }, 8 | "offload_param": { 9 | "device": "cpu", 10 | "pin_memory": true 11 | }, 12 | "overlap_comm": true, 13 | "contiguous_gradients": true, 14 | "sub_group_size": 0, 15 | "reduce_bucket_size": "auto", 16 | "stage3_prefetch_bucket_size": "auto", 17 | "stage3_param_persistence_threshold": "auto", 18 | "stage3_max_live_parameters": 0, 19 | "stage3_max_reuse_distance": 0, 20 | "stage3_gather_16bit_weights_on_model_save": true 21 | }, 22 | "bf16": { 23 | "enabled": true, 24 | "auto_cast": false, 25 | "loss_scale": 0, 26 | "initial_scale_power": 32, 27 | "loss_scale_window": 1000, 28 | "hysteresis": 2, 29 | "min_loss_scale": 1 30 | }, 31 | "optimizer": { 32 | "type": "AdamW", 33 | "params": { 34 | "lr": "auto", 35 | "betas": "auto", 36 | "eps": "auto", 37 | "weight_decay": "auto" 38 | } 39 | }, 40 | "train_batch_size": "auto", 41 | "train_micro_batch_size_per_gpu": "auto", 42 | "gradient_accumulation_steps": "auto", 43 | "wall_clock_breakdown": false 44 | } 45 | -------------------------------------------------------------------------------- 
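The ZeRO-2 and ZeRO-3 configs above, together with the accelerate-oriented stage-3 config that follows, configure DeepSpeed for the Trainer-based scripts in src/deita/alignment. A rough launch sketch (the checkpoint name, data path, and hyperparameters are purely illustrative, not values taken from this repository):

    torchrun --nproc_per_node=8 src/deita/alignment/train.py \
        --model_name_or_path meta-llama/Llama-2-13b-hf \
        --data_path data/deita_sft.json \
        --conv_template vicuna_v1.1 \
        --output_dir outputs/deita-sft \
        --bf16 True \
        --model_max_length 2048 \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 8 \
        --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json

The "auto" fields in these JSON files are resolved by the Trainer from the corresponding command-line arguments at launch time.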
/src/deita/ds_configs/stage3_no_offloading_accelerate.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "overlap_comm": true, 8 | "contiguous_gradients": true, 9 | "sub_group_size": 1e9, 10 | "reduce_bucket_size": "auto", 11 | "stage3_prefetch_bucket_size": "auto", 12 | "stage3_param_persistence_threshold": "auto", 13 | "stage3_max_live_parameters": 1e9, 14 | "stage3_max_reuse_distance": 1e9, 15 | "stage3_gather_16bit_weights_on_model_save": true 16 | }, 17 | "gradient_accumulation_steps": "auto", 18 | "gradient_clipping": "auto", 19 | "steps_per_print": 1e5, 20 | "train_batch_size": "auto", 21 | "train_micro_batch_size_per_gpu": "auto", 22 | "wall_clock_breakdown": false 23 | } -------------------------------------------------------------------------------- /src/deita/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .embed_pipeline import EmbedPipeline 2 | from .score_pipeline import ScorePipeline 3 | from .filter_pipeline import FilterPipeline 4 | from .base import PipelineRegistry 5 | from typing import Callable 6 | 7 | PipelineRegistry.register("score_pipeline", ScorePipeline) 8 | PipelineRegistry.register("embed_pipeline", EmbedPipeline) 9 | PipelineRegistry.register("filter_pipeline", FilterPipeline) 10 | 11 | class Pipeline: 12 | 13 | def __new__(cls, name, **kwargs) -> Callable: 14 | 15 | PipelineClass = PipelineRegistry.get_pipeline(name) 16 | return PipelineClass(name, **kwargs) -------------------------------------------------------------------------------- /src/deita/pipeline/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Any, Dict, List, Optional, Tuple, Union, Callable 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | class BasePipeline: 8 | 9 | def __init__(self, name: str, data_path: str, **kwargs) -> None: 10 | 11 | self.name = name 12 | self.data_path = data_path 13 | 14 | def _load_data(self, data_path: str) -> None: 15 | 16 | """ 17 | Load data from data_path. 18 | 19 | data_path: str - path to json data file. 
20 | """ 21 | 22 | try: 23 | with open(data_path, "r") as f: 24 | data = json.load(f) 25 | except json.JSONDecodeError: 26 | with open(data_path, "r") as f: 27 | data = [json.loads(line) for line in f] 28 | 29 | return data 30 | 31 | def _load_other_data(self, other_data_path: str) -> None: 32 | raise NotImplementedError 33 | 34 | def _save_data(self, data_path: str, data_format: str) -> None: 35 | raise NotImplementedError 36 | 37 | def _preprocess(self, json_data, other_data) -> None: 38 | raise NotImplementedError 39 | 40 | def _forward(self, preprocessed_data) -> None: 41 | raise NotImplementedError 42 | 43 | def run(self) -> None: 44 | 45 | json_data = self._load_data(self.data_path) 46 | 47 | other_data = None 48 | if hasattr(self, "other_data_path"): 49 | other_data = self._load_other_data(self.other_data_path) 50 | 51 | preprocessed_data = self._preprocess(json_data, other_data) 52 | results = self._forward(preprocessed_data) 53 | self._save_data(json_data, results) 54 | logger.info(f"Pipeline {self.name} run complete.") 55 | 56 | class PipelineRegistry: 57 | 58 | registry = {} 59 | 60 | @classmethod 61 | def register(cls, name: str, pipline_class: Callable): 62 | 63 | if name in cls.registry: 64 | raise ValueError(f"Pipeline {name} already registered.") 65 | cls.registry[name] = pipline_class 66 | 67 | @classmethod 68 | def get_pipeline(cls, name: str): 69 | 70 | if name not in cls.registry: 71 | raise ValueError(f"Pipeline {name} not registered.") 72 | return cls.registry[name] 73 | 74 | -------------------------------------------------------------------------------- /src/deita/pipeline/embed_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from deita.pipeline.base import BasePipeline 4 | from deita.selection.embedder import CLM_Embedder 5 | import logging 6 | import pandas 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class EmbedPipeline(BasePipeline): 11 | 12 | def __init__(self, name: str, data_path: str, **kwargs) -> None: 13 | 14 | self.name = name 15 | self.data_path = data_path 16 | self.is_compression = kwargs.get("is_compression", False) 17 | 18 | self.data_format = "sharegpt" # only support sharegpt for now 19 | self.output_path = kwargs.get("output_path") 20 | 21 | if not os.path.exists(self.output_path): 22 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True) 23 | 24 | self.embedder = CLM_Embedder(**kwargs) 25 | 26 | def _preprocess(self, json_data, other_data) -> List: 27 | return json_data 28 | 29 | def _forward(self, preprocessed_data) -> List: 30 | 31 | all_embeddings_list = self.embedder.encode_samples(preprocessed_data) 32 | 33 | logger.info(f"{len(all_embeddings_list)}") 34 | logger.info("Finished embedding") 35 | 36 | return all_embeddings_list 37 | 38 | def _save_data(self, json_data: List, results: List) -> None: 39 | 40 | # We use dataframe to save the results 41 | df = pandas.DataFrame(results) 42 | 43 | if self.embedder.accelerator.is_main_process: 44 | df.sort_values(by = "idx", inplace = True) 45 | df.reset_index(drop = True, inplace = True) 46 | 47 | if not self.is_compression: 48 | df.to_pickle(self.output_path) 49 | logger.info(f"Saved pickle to {self.output_path}") 50 | else: 51 | df.to_pickle(self.output_path, "zip") 52 | logger.info(f"Saved pickle to {self.output_path} with zip compression") 53 | 54 | 55 | -------------------------------------------------------------------------------- /src/deita/pipeline/filter_pipeline.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | from typing import List 5 | from deita.pipeline.base import BasePipeline 6 | from deita.pipeline.utils import sort_key_split 7 | from deita.selection.filter import Combined_Filter 8 | import logging 9 | import pandas 10 | import numpy as np 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class FilterPipeline(BasePipeline): 15 | 16 | def __init__(self, name: str, data_path: str, **kwargs) -> None: 17 | 18 | self.name = name 19 | self.data_path = data_path 20 | self.other_data_path = kwargs.get("other_data_path") 21 | self.is_compression = kwargs.get("is_compression", False) 22 | 23 | self.data_format = "sharegpt" # only support sharegpt for now 24 | self.output_path = kwargs.get("output_path") 25 | self.sort_key = kwargs.get("sort_key") 26 | self.sort_key = sort_key_split(self.sort_key) 27 | kwargs["sort_key"] = self.sort_key 28 | 29 | if not os.path.exists(self.output_path): 30 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True) 31 | 32 | self.filter = Combined_Filter(**kwargs) 33 | 34 | def _load_other_data(self, other_data_path: str) -> None: 35 | """ 36 | Load Embedding Data 37 | """ 38 | 39 | if self.is_compression: 40 | embedding_data = pd.read_pickle(other_data_path, "zip") 41 | else: 42 | embedding_data = pd.read_pickle(other_data_path) 43 | 44 | return embedding_data 45 | 46 | def _preprocess(self, json_data, other_data) -> List: 47 | 48 | """ 49 | json_data: List - data to be filtered 50 | other_data: pd.DataFrame - embedding data 51 | """ 52 | 53 | if isinstance(other_data, np.ndarray): 54 | df_data = pd.DataFrame([{"embedding": other_data[i]} for i in range(other_data.shape[0])]) 55 | elif isinstance(other_data, pd.DataFrame): 56 | df_data = other_data 57 | else: 58 | raise ValueError("other_data must be either np.array or pd.DataFrame") 59 | 60 | if "idx" not in df_data.columns: 61 | df_data["idx"] = df_data.index 62 | 63 | df_json = pd.DataFrame(json_data) 64 | 65 | for sk in self.sort_key: 66 | df_data[sk] = df_json[sk].tolist() 67 | 68 | return df_data 69 | 70 | def _forward(self, preprocessed_data) -> List: 71 | 72 | selected_data = self.filter.filter(preprocessed_data) 73 | selected_data_indices = selected_data["idx"].tolist() 74 | 75 | logger.info(f"Selected Data Number: {len(selected_data_indices)}") 76 | logger.info("Finished Combined Selection") 77 | 78 | return selected_data_indices 79 | 80 | def _save_data(self, json_data: List, results: List) -> None: 81 | 82 | selected_data = [] 83 | 84 | for idx in results: 85 | selected_data.append(json_data[idx]) 86 | 87 | with open(self.output_path, "w") as f: 88 | json.dump(selected_data, f, indent=2, ensure_ascii=False) 89 | 90 | 91 | -------------------------------------------------------------------------------- /src/deita/pipeline/score_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from deita.pipeline.base import BasePipeline 4 | from deita.selection.scorer import Llama_Scorer, Mistral_Scorer 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | from deita.pipeline.utils import load_data 7 | import logging 8 | from tqdm import tqdm 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class ScorePipeline(BasePipeline): 13 | 14 | def __init__(self, name: str, data_path: str, **kwargs) -> None: 15 | 16 | self.name = name 17 | self.data_path = data_path 18 | 19 | 
self.data_format = "sharegpt" # only support sharegpt for now 20 | 21 | scorer = kwargs.get("scorer") 22 | is_vllm = kwargs.get("is_vllm") 23 | scorer_name_or_path = kwargs.get("scorer_name_or_path") 24 | self.score_type = kwargs.get("score_type") 25 | 26 | if scorer == "llama": 27 | self.model = Llama_Scorer(scorer_name_or_path, is_vllm) 28 | elif scorer == "mistral": 29 | self.model = Mistral_Scorer(scorer_name_or_path, is_vllm) 30 | 31 | self.output_path = kwargs.get("output_path") 32 | 33 | if not os.path.exists(self.output_path): 34 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True) 35 | 36 | def _load_sharegpt(self, data: str) -> List: 37 | 38 | preprocessed_data = [] 39 | 40 | for sample_id, item in enumerate(data): 41 | 42 | preprocessed_item = [] 43 | 44 | for idx in range(len(item["conversations"])): 45 | 46 | if idx % 2 != 0: 47 | continue 48 | 49 | if idx != len(item["conversations"]) - 1: 50 | preprocessed_item.append({"instruction": item["conversations"][idx]["value"], "response": item["conversations"][idx+1]["value"]}) 51 | else: 52 | preprocessed_item.append({"instruction": item["conversations"][idx]["value"], "response": ""}) 53 | 54 | preprocessed_data.append({"conversations": preprocessed_item, "n_conv": len(preprocessed_item)}) 55 | 56 | return preprocessed_data 57 | 58 | def _inject_sharegpt(self, json_data: List, results: List) -> None: 59 | 60 | for sample_id in range(len(json_data)): 61 | 62 | json_data[sample_id][f"{self.score_type}_scores"] = [] 63 | 64 | for item in results[sample_id]["conversations"]: 65 | json_data[sample_id][f"{self.score_type}_scores"].append(float(item[f"{self.score_type}_score"])) 66 | 67 | 68 | def _preprocess(self, json_data, other_data) -> List: 69 | 70 | if self.data_format == "sharegpt": 71 | preprocessed_data = self._load_sharegpt(json_data) 72 | else: 73 | raise ValueError(f"Data format {self.data_format} not supported.") 74 | 75 | return preprocessed_data 76 | 77 | def _forward(self, preprocessed_data) -> List: 78 | 79 | for convs in tqdm(preprocessed_data, total = len(preprocessed_data)): 80 | 81 | for conv in convs["conversations"]: 82 | 83 | if self.score_type.lower() == "complexity": 84 | score = self.model.infer_complexity(conv["instruction"]) 85 | elif self.score_type.lower() == "quality": 86 | score = self.model.infer_quality(conv["instruction"], conv["response"]) 87 | else: 88 | raise ValueError(f"Score type {self.score_type} not supported.") 89 | 90 | conv[f"{self.score_type}_score"] = score 91 | 92 | return preprocessed_data 93 | 94 | def _save_data(self, json_data: List, results: List) -> None: 95 | 96 | if self.data_format == "sharegpt": 97 | self._inject_sharegpt(json_data, results) 98 | else: 99 | raise ValueError(f"Data format {self.data_format} not supported.") 100 | 101 | with open(self.output_path, "w") as f: 102 | json.dump(json_data, f, indent=2, ensure_ascii=False) 103 | logger.info(f"Saved results to {self.output_path}.") -------------------------------------------------------------------------------- /src/deita/pipeline/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | INSTRUCTION_ONLY_TYPE = ["complexity"] 5 | INSTRUCTION_RESPONSE_TYPE = ["quality"] 6 | 7 | def sort_key_split(sort_key: str) -> List: 8 | """ 9 | Split sort_key into a list of sort keys. 10 | 11 | sort_key: str - sort key to split. 12 | """ 13 | if "," in sort_key: 14 | return sort_key.split(",") 15 | elif "." 
in sort_key: 16 | return sort_key.split(".") 17 | elif "+" in sort_key: 18 | return sort_key.split("+") 19 | else: 20 | raise ValueError("sort_key must be a string with delimiter ',' or '.' or '+'.") 21 | 22 | def load_data(data_path: str) -> List: 23 | 24 | """ 25 | Load data from data_path. 26 | 27 | data_path: str - path to json data file. 28 | """ 29 | 30 | try: 31 | with open(data_path, "r") as f: 32 | data = json.load(f) 33 | except json.JSONDecodeError: 34 | with open(data_path, "r") as f: 35 | data = [json.loads(line) for line in f] 36 | 37 | return data -------------------------------------------------------------------------------- /src/deita/selection/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/deita/selection/embedder/__init__.py: -------------------------------------------------------------------------------- 1 | from .clm_embedder import CLM_Embedder 2 | 3 | __all__ = ["CLM_Embedder"] -------------------------------------------------------------------------------- /src/deita/selection/embedder/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch 4 | from datasets import Dataset 5 | from transformers import ( 6 | AutoModelForCausalLM, 7 | AutoTokenizer, 8 | ) 9 | from deita.selection.embedder.utils import batchlize 10 | from transformers.trainer_pt_utils import LabelSmoother 11 | from accelerate import Accelerator 12 | 13 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 14 | 15 | 16 | class Embedder: 17 | 18 | def __init__(self, model_name_or_path, **kwargs) -> None: 19 | 20 | self.compute_dtype = ( 21 | torch.float16 22 | if kwargs.get('fp16', False) 23 | else (torch.bfloat16 if kwargs.get("bfloat16", False) else torch.float32) 24 | ) 25 | 26 | self.model_name_or_path = model_name_or_path 27 | self.max_length = kwargs.get('max_length') 28 | self.use_flash_attention = kwargs.get('use_flash_attention') 29 | self.batch_size_per_device = kwargs.get('batch_size_per_device') 30 | self.conv_template = kwargs.get('conv_template') 31 | self.only_answer = kwargs.get('only_answer') 32 | self.random_shuffle = kwargs.get('random_shuffle') 33 | 34 | self.local_rank = int(os.getenv("LOCAL_RANK", "0")) 35 | self.world_size = int(os.getenv("WORLD_SIZE", "1")) 36 | torch.cuda.set_device(self.local_rank) # NOTE: this will raise an error on a CPU-only machine 37 | 38 | self.accelerator = Accelerator() 39 | self.accelerator.wait_for_everyone() 40 | 41 | batch_size = self.batch_size_per_device * self.world_size 42 | self.minibatch_size = batch_size 43 | 44 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, 45 | model_max_length = self.max_length, 46 | padding_side = "right", 47 | use_fast = False) 48 | 49 | if "mistral" in self.model_name_or_path: 50 | self.tokenizer.padding_side = "left" 51 | 52 | self.model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, 53 | torch_dtype = self.compute_dtype, 54 | use_flash_attention_2 = self.use_flash_attention) 55 | 56 | if self.tokenizer.pad_token is None: 57 | self.tokenizer.pad_token = self.tokenizer.unk_token 58 | 59 | def rank0_print(self, *args): 60 | if self.local_rank == 0: 61 | print(*args) 62 | 63 | def compute_length(self, conversations: list, cnt_field = "response"): 64 | 65 | all_lengths = [] 66 | for conv in conversations: 67 | 68 | cur_length = 0 69 | 70 | for i, c in enumerate(conv): 71 | if
cnt_field == "response": 72 | if i % 2 == 1: 73 | cur_length += len(c["value"]) 74 | elif cnt_field == "instruction": 75 | if i % 2 == 0: 76 | cur_length += len(c["value"]) 77 | else: 78 | cur_length += len(c["value"]) 79 | 80 | all_lengths.append(cur_length) 81 | 82 | return all_lengths 83 | 84 | def create_databuffer(self, conversations: list, sort_by_length = False): 85 | 86 | all_lengths = self.compute_length(conversations) 87 | 88 | dataset_size = len(conversations) 89 | dataset_buf = [] 90 | for idx in range(dataset_size): 91 | 92 | dataset_buf.append({ 93 | "conversations": conversations[idx], 94 | "specific_length": all_lengths[idx], 95 | "input_idx": idx 96 | }) 97 | 98 | if sort_by_length: 99 | dataset_buf = sorted(dataset_buf, key = lambda x: x["specific_length"]) 100 | 101 | return dataset_buf, dataset_size 102 | 103 | def create_dataloader(self, dataset_buf: Dataset): 104 | 105 | dataloader = batchlize( 106 | dataset_buf, 107 | self.minibatch_size, 108 | self.random_shuffle, 109 | sort = self.minibatch_size > 1 110 | ) 111 | 112 | print(f"Successfully create dataloader with size {len(dataloader)},batch_size {self.minibatch_size}.") 113 | 114 | return dataloader 115 | 116 | def probe_samples(self, model, data: list): 117 | 118 | raise NotImplementedError 119 | 120 | def collect_grad(self): 121 | 122 | raise NotImplementedError -------------------------------------------------------------------------------- /src/deita/selection/embedder/clm_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | 4 | from tqdm import tqdm 5 | import torch 6 | from datasets import Dataset 7 | import torch.distributed as dist 8 | from deita.selection.embedder.base import Embedder 9 | from deita.selection.embedder.utils import DataCollatorForSupervisedDataset, preprocess 10 | 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class CLM_Embedder(Embedder): 16 | 17 | def __init__(self, model_name_or_path, **kwargs): 18 | super().__init__(model_name_or_path, **kwargs) 19 | 20 | def encode_samples(self, data): 21 | 22 | conversations = [item["conversations"] for item in data] 23 | dataset_buf, data_size = self.create_databuffer(conversations, sort_by_length = True) 24 | raw_dataset = Dataset.from_list(dataset_buf) 25 | 26 | preprocess_func = partial(preprocess, 27 | conv_template = self.conv_template, 28 | only_answer = self.only_answer, 29 | max_length = self.max_length, 30 | tokenizer = self.tokenizer) 31 | 32 | with self.accelerator.main_process_first(): 33 | tokenized_datasets = raw_dataset.map( 34 | preprocess_func, 35 | batched = True, 36 | num_proc = 32, 37 | remove_columns = ["conversations", "specific_length"], 38 | desc = "Tokenizing and reformatting instruction data" 39 | ) 40 | 41 | data_collator = DataCollatorForSupervisedDataset(tokenizer = self.tokenizer) 42 | dataloader = torch.utils.data.DataLoader(tokenized_datasets, batch_size = self.batch_size_per_device, collate_fn = data_collator) 43 | 44 | model, dataloader = self.accelerator.prepare(self.model, dataloader) 45 | 46 | all_embeddings_list = [] 47 | 48 | total_samples = len(tokenized_datasets) 49 | total_batches = len(dataloader) 50 | last_batch_size = total_samples % self.minibatch_size if total_samples % self.minibatch_size != 0 else self.minibatch_size 51 | 52 | for b_idx, batch in enumerate(tqdm(dataloader, total = len(tokenized_datasets) // self.minibatch_size, disable = not 
self.accelerator.is_local_main_process)): 53 | 54 | model.eval() 55 | 56 | batch_idx = batch["idx"] 57 | attention_mask = batch["attention_mask"] 58 | 59 | outputs = model(input_ids = batch["input_ids"], attention_mask = batch["attention_mask"], output_hidden_states = True) 60 | 61 | seq_len = attention_mask.sum(1, keepdim = True) 62 | 63 | if self.tokenizer.padding_side == "right": 64 | last_hidden_state = outputs.hidden_states[-1][torch.arange(seq_len.size(0))[:, None], seq_len - 1] 65 | elif self.tokenizer.padding_side == "left": 66 | last_hidden_state = outputs.hidden_states[-1][:, -1] 67 | else: 68 | raise ValueError("Invalid padding strategy") 69 | 70 | sample_idx = batch_idx.tolist() 71 | sample_dict = [{"embedding": lst_hs, "idx": s_id} for lst_hs, s_id in zip(last_hidden_state.tolist(), sample_idx)] 72 | 73 | if(self.world_size > 1): 74 | all_process_embeddings = [[] for _ in range(self.world_size)] 75 | dist.gather_object(sample_dict, all_process_embeddings if dist.get_rank() == 0 else None, dst=0) 76 | else: 77 | all_process_embeddings = [sample_dict] 78 | 79 | if self.accelerator.is_local_main_process: 80 | if b_idx == total_batches - 1: 81 | for process_list in all_process_embeddings[:last_batch_size]: 82 | all_embeddings_list.extend(process_list) 83 | else: 84 | for process_list in all_process_embeddings: 85 | all_embeddings_list.extend(process_list) 86 | 87 | return all_embeddings_list -------------------------------------------------------------------------------- /src/deita/selection/embedder/conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conversation prompt templates. 3 | """ 4 | 5 | import dataclasses 6 | from enum import auto, IntEnum 7 | from typing import List, Any, Dict 8 | 9 | 10 | class SeparatorStyle(IntEnum): 11 | """Separator styles.""" 12 | 13 | ADD_COLON_SINGLE = auto() 14 | ADD_COLON_TWO = auto() 15 | ADD_COLON_SPACE_SINGLE = auto() 16 | NO_COLON_SINGLE = auto() 17 | NO_COLON_TWO = auto() 18 | ADD_NEW_LINE_SINGLE = auto() 19 | LLAMA2 = auto() 20 | CHATGLM = auto() 21 | CHATML = auto() 22 | CHATINTERN = auto() 23 | DOLLY = auto() 24 | RWKV = auto() 25 | PHOENIX = auto() 26 | ROBIN = auto() 27 | 28 | 29 | @dataclasses.dataclass 30 | class Conversation: 31 | """A class that manages prompt templates and keeps all conversation history.""" 32 | 33 | # The name of this template 34 | name: str 35 | # The system prompt 36 | system: str 37 | # Two roles 38 | roles: List[str] 39 | # All messages. Each item is (role, message). 
40 | messages: List[List[str]] 41 | # The number of few shot examples 42 | offset: int 43 | # Separators 44 | sep_style: SeparatorStyle 45 | sep: str 46 | sep2: str = None 47 | # Stop criteria (the default one is EOS token) 48 | stop_str: str = None 49 | # Stops generation if meeting any token in this list 50 | stop_token_ids: List[int] = None 51 | 52 | def get_prompt(self) -> str: 53 | """Get the prompt for generation.""" 54 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 55 | ret = self.system + self.sep 56 | for role, message in self.messages: 57 | if message: 58 | ret += role + ": " + message + self.sep 59 | else: 60 | ret += role + ":" 61 | return ret 62 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: 63 | seps = [self.sep, self.sep2] 64 | ret = self.system + seps[0] 65 | for i, (role, message) in enumerate(self.messages): 66 | if message: 67 | ret += role + ": " + message + seps[i % 2] 68 | else: 69 | ret += role + ":" 70 | return ret 71 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: 72 | ret = self.system + self.sep 73 | for role, message in self.messages: 74 | if message: 75 | ret += role + ": " + message + self.sep 76 | else: 77 | ret += role + ": " # must be end with a space 78 | return ret 79 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: 80 | ret = "" if self.system == "" else self.system + self.sep 81 | for role, message in self.messages: 82 | if message: 83 | ret += role + "\n" + message + self.sep 84 | else: 85 | ret += role + "\n" 86 | return ret 87 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 88 | ret = self.system 89 | for role, message in self.messages: 90 | if message: 91 | ret += role + message + self.sep 92 | else: 93 | ret += role 94 | return ret 95 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO: 96 | seps = [self.sep, self.sep2] 97 | ret = self.system 98 | for i, (role, message) in enumerate(self.messages): 99 | if message: 100 | ret += role + message + seps[i % 2] 101 | else: 102 | ret += role 103 | return ret 104 | elif self.sep_style == SeparatorStyle.RWKV: 105 | ret = self.system 106 | for i, (role, message) in enumerate(self.messages): 107 | if message: 108 | ret += ( 109 | role 110 | + ": " 111 | + message.replace("\r\n", "\n").replace("\n\n", "\n") 112 | ) 113 | ret += "\n\n" 114 | else: 115 | ret += role + ":" 116 | return ret 117 | elif self.sep_style == SeparatorStyle.LLAMA2: 118 | seps = [self.sep, self.sep2] 119 | ret = "" 120 | for i, (role, message) in enumerate(self.messages): 121 | if message: 122 | if i == 0: 123 | ret += self.system + message 124 | else: 125 | ret += role + " " + message + seps[i % 2] 126 | else: 127 | ret += role 128 | return ret 129 | elif self.sep_style == SeparatorStyle.CHATGLM: 130 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 131 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 132 | round_add_n = 1 if self.name == "chatglm2" else 0 133 | if self.system: 134 | ret = self.system + self.sep 135 | else: 136 | ret = "" 137 | 138 | for i, (role, message) in enumerate(self.messages): 139 | if i % 2 == 0: 140 | ret += f"[Round {i//2 + round_add_n}]{self.sep}" 141 | 142 | if message: 143 | ret += f"{role}:{message}{self.sep}" 144 | else: 145 | ret += f"{role}:" 146 | return ret 147 | elif self.sep_style == SeparatorStyle.CHATML: 148 | ret = "" if self.system == "" else self.system + self.sep + "\n" 149 | for role, 
message in self.messages: 150 | if message: 151 | ret += role + "\n" + message + self.sep + "\n" 152 | else: 153 | ret += role + "\n" 154 | return ret 155 | elif self.sep_style == SeparatorStyle.CHATINTERN: 156 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 157 | seps = [self.sep, self.sep2] 158 | ret = self.system 159 | for i, (role, message) in enumerate(self.messages): 160 | if i % 2 == 0: 161 | ret += "" 162 | if message: 163 | ret += role + ":" + message + seps[i % 2] + "\n" 164 | else: 165 | ret += role + ":" 166 | return ret 167 | elif self.sep_style == SeparatorStyle.DOLLY: 168 | seps = [self.sep, self.sep2] 169 | ret = self.system 170 | for i, (role, message) in enumerate(self.messages): 171 | if message: 172 | ret += role + ":\n" + message + seps[i % 2] 173 | if i % 2 == 1: 174 | ret += "\n\n" 175 | else: 176 | ret += role + ":\n" 177 | return ret 178 | elif self.sep_style == SeparatorStyle.PHOENIX: 179 | ret = self.system 180 | for role, message in self.messages: 181 | if message: 182 | ret += role + ": " + "" + message + "" 183 | else: 184 | ret += role + ": " + "" 185 | return ret 186 | elif self.sep_style == SeparatorStyle.ROBIN: 187 | ret = self.system + self.sep 188 | for role, message in self.messages: 189 | if message: 190 | ret += role + ":\n" + message + self.sep 191 | else: 192 | ret += role + ":\n" 193 | return ret 194 | else: 195 | raise ValueError(f"Invalid style: {self.sep_style}") 196 | 197 | def append_message(self, role: str, message: str): 198 | """Append a new message.""" 199 | self.messages.append([role, message]) 200 | 201 | def update_last_message(self, message: str): 202 | """Update the last output. 203 | 204 | The last message is typically set to be None when constructing the prompt, 205 | so we need to update it in-place after getting the response from a model. 
206 | """ 207 | self.messages[-1][1] = message 208 | 209 | def to_gradio_chatbot(self): 210 | """Convert the conversation to gradio chatbot format.""" 211 | ret = [] 212 | for i, (role, msg) in enumerate(self.messages[self.offset :]): 213 | if i % 2 == 0: 214 | ret.append([msg, None]) 215 | else: 216 | ret[-1][-1] = msg 217 | return ret 218 | 219 | def to_openai_api_messages(self): 220 | """Convert the conversation to OpenAI chat completion format.""" 221 | ret = [{"role": "system", "content": self.system}] 222 | 223 | for i, (_, msg) in enumerate(self.messages[self.offset :]): 224 | if i % 2 == 0: 225 | ret.append({"role": "user", "content": msg}) 226 | else: 227 | if msg is not None: 228 | ret.append({"role": "assistant", "content": msg}) 229 | return ret 230 | 231 | def copy(self): 232 | return Conversation( 233 | name=self.name, 234 | system=self.system, 235 | roles=self.roles, 236 | messages=[[x, y] for x, y in self.messages], 237 | offset=self.offset, 238 | sep_style=self.sep_style, 239 | sep=self.sep, 240 | sep2=self.sep2, 241 | stop_str=self.stop_str, 242 | stop_token_ids=self.stop_token_ids, 243 | ) 244 | 245 | def dict(self): 246 | return { 247 | "template_name": self.name, 248 | "system": self.system, 249 | "roles": self.roles, 250 | "messages": self.messages, 251 | "offset": self.offset, 252 | } 253 | 254 | 255 | # A global registry for all conversation templates 256 | conv_templates: Dict[str, Conversation] = {} 257 | 258 | 259 | def register_conv_template(template: Conversation, override: bool = False): 260 | """Register a new conversation template.""" 261 | if not override: 262 | assert ( 263 | template.name not in conv_templates 264 | ), f"{template.name} has been registered." 265 | 266 | conv_templates[template.name] = template 267 | 268 | 269 | def get_conv_template(name: str) -> Conversation: 270 | """Get a conversation template.""" 271 | return conv_templates[name].copy() 272 | 273 | 274 | # A template with a one-shot conversation example 275 | register_conv_template( 276 | Conversation( 277 | name="one_shot", 278 | system="A chat between a curious human and an artificial intelligence assistant. " 279 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 280 | roles=("Human", "Assistant"), 281 | messages=( 282 | ( 283 | "Human", 284 | "Got any creative ideas for a 10 year old’s birthday?", 285 | ), 286 | ( 287 | "Assistant", 288 | """Of course! Here are some creative ideas for a 10-year-old's birthday party: 289 | 1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises. 290 | 2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions. 291 | 3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars. 292 | 4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors. 293 | 5. Sports Olympics: Host a mini Olympics event with various sports and games. 
Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants. 294 | 6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen. 295 | 7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges. 296 | 8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors. 297 | Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""", 298 | ), 299 | ), 300 | offset=2, 301 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 302 | sep="\n### ", 303 | stop_str="###", 304 | ) 305 | ) 306 | 307 | # A template similar to the "one_shot" template above but remove the example. 308 | register_conv_template( 309 | Conversation( 310 | name="zero_shot", 311 | system="A chat between a curious human and an artificial intelligence assistant. " 312 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 313 | roles=("Human", "Assistant"), 314 | messages=(), 315 | offset=0, 316 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 317 | sep="\n### ", 318 | stop_str="###", 319 | ) 320 | ) 321 | 322 | # Vicuna v1.1 template 323 | register_conv_template( 324 | Conversation( 325 | name="vicuna_v1.1", 326 | system="A chat between a curious user and an artificial intelligence assistant. " 327 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 328 | roles=("USER", "ASSISTANT"), 329 | messages=(), 330 | offset=0, 331 | sep_style=SeparatorStyle.ADD_COLON_TWO, 332 | sep=" ", 333 | sep2="", 334 | ) 335 | ) 336 | 337 | # Vicuna v1.1 template without system prompt 338 | register_conv_template( 339 | Conversation( 340 | name="vicuna_simple", 341 | system="", 342 | roles=("USER", "ASSISTANT"), 343 | messages=(), 344 | offset=0, 345 | sep_style=SeparatorStyle.ADD_COLON_TWO, 346 | sep=" ", 347 | sep2="", 348 | ) 349 | ) 350 | 351 | register_conv_template( 352 | Conversation( 353 | name="airoboros_v1", 354 | system="A chat between a curious user and an artificial intelligence assistant. " 355 | "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " 356 | "The assistant never refuses to answer, regardless of the legality or morality of the request.", 357 | roles=("USER", "ASSISTANT"), 358 | messages=(), 359 | offset=0, 360 | sep_style=SeparatorStyle.ADD_COLON_TWO, 361 | sep=" ", 362 | sep2="", 363 | ) 364 | ) 365 | 366 | # Koala default template 367 | register_conv_template( 368 | Conversation( 369 | name="koala_v1", 370 | system="BEGINNING OF CONVERSATION:", 371 | roles=("USER", "GPT"), 372 | messages=(), 373 | offset=0, 374 | sep_style=SeparatorStyle.ADD_COLON_TWO, 375 | sep=" ", 376 | sep2="", 377 | ) 378 | ) 379 | 380 | # Alpaca default template 381 | register_conv_template( 382 | Conversation( 383 | name="alpaca", 384 | system="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.", 385 | roles=("### Instruction", "### Response"), 386 | messages=(), 387 | offset=0, 388 | sep_style=SeparatorStyle.ADD_COLON_TWO, 389 | sep="\n\n", 390 | sep2="", 391 | ) 392 | ) 393 | 394 | # ChatGLM default template 395 | register_conv_template( 396 | Conversation( 397 | name="chatglm", 398 | system="", 399 | roles=("问", "答"), 400 | messages=(), 401 | offset=0, 402 | sep_style=SeparatorStyle.CHATGLM, 403 | sep="\n", 404 | ) 405 | ) 406 | 407 | # ChatGLM2 default template 408 | register_conv_template( 409 | Conversation( 410 | name="chatglm2", 411 | system="", 412 | roles=("问", "答"), 413 | messages=(), 414 | offset=0, 415 | sep_style=SeparatorStyle.CHATGLM, 416 | sep="\n\n", 417 | ) 418 | ) 419 | 420 | # Dolly V2 default template 421 | register_conv_template( 422 | Conversation( 423 | name="dolly_v2", 424 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", 425 | roles=("### Instruction", "### Response"), 426 | messages=(), 427 | offset=0, 428 | sep_style=SeparatorStyle.DOLLY, 429 | sep="\n\n", 430 | sep2="### End", 431 | ) 432 | ) 433 | 434 | # OpenAssistant Pythia default template 435 | register_conv_template( 436 | Conversation( 437 | name="oasst_pythia", 438 | system="", 439 | roles=("<|prompter|>", "<|assistant|>"), 440 | messages=(), 441 | offset=0, 442 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 443 | sep="<|endoftext|>", 444 | ) 445 | ) 446 | 447 | # OpenAssistant default template 448 | register_conv_template( 449 | Conversation( 450 | name="oasst_llama", 451 | system="", 452 | roles=("<|prompter|>", "<|assistant|>"), 453 | messages=(), 454 | offset=0, 455 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 456 | sep="", 457 | ) 458 | ) 459 | 460 | # Tulu default template 461 | register_conv_template( 462 | Conversation( 463 | name="tulu", 464 | system="", 465 | roles=("<|user|>", "<|assistant|>"), 466 | messages=(), 467 | offset=0, 468 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, 469 | sep="\n", 470 | ) 471 | ) 472 | 473 | # StableLM Alpha default template 474 | register_conv_template( 475 | Conversation( 476 | name="stablelm", 477 | system="""<|SYSTEM|># StableLM Tuned (Alpha version) 478 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. 479 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 480 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 481 | - StableLM will refuse to participate in anything that could harm a human. 482 | """, 483 | roles=("<|USER|>", "<|ASSISTANT|>"), 484 | messages=(), 485 | offset=0, 486 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 487 | sep="", 488 | stop_token_ids=[50278, 50279, 50277, 1, 0], 489 | ) 490 | ) 491 | 492 | # Baize default template 493 | register_conv_template( 494 | Conversation( 495 | name="baize", 496 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. 
The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n", 497 | roles=("[|Human|]", "[|AI|]"), 498 | messages=( 499 | ("[|Human|]", "Hello!"), 500 | ("[|AI|]", "Hi!"), 501 | ), 502 | offset=2, 503 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 504 | sep="\n", 505 | stop_str="[|Human|]", 506 | ) 507 | ) 508 | 509 | # RWKV-4-Raven default template 510 | register_conv_template( 511 | Conversation( 512 | name="rwkv", 513 | system="", 514 | roles=("Bob", "Alice"), 515 | messages=( 516 | ("Bob", "hi"), 517 | ( 518 | "Alice", 519 | "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.", 520 | ), 521 | ), 522 | offset=2, 523 | sep_style=SeparatorStyle.RWKV, 524 | sep="", 525 | stop_str="\n\n", 526 | ) 527 | ) 528 | 529 | # Buddy default template 530 | register_conv_template( 531 | Conversation( 532 | name="openbuddy", 533 | system="""Consider a conversation between User (a human) and Assistant (named Buddy). 534 | Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy 535 | Buddy cannot access the Internet. 536 | Buddy can fluently speak the user's language (e.g. English, Chinese). 537 | Buddy can generate poems, stories, code, essays, songs, parodies, and more. 538 | Buddy possesses vast knowledge about the world, history, and culture. 539 | Buddy's responses are always safe, creative, high-quality, human-like, and interesting. 540 | Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. 541 | 542 | User: Hi. 543 | Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", 544 | roles=("User", "Assistant"), 545 | messages=(), 546 | offset=0, 547 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 548 | sep="\n", 549 | ) 550 | ) 551 | 552 | # Phoenix default template 553 | register_conv_template( 554 | Conversation( 555 | name="phoenix", 556 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", 557 | roles=("Human", "Assistant"), 558 | messages=(), 559 | offset=0, 560 | sep_style=SeparatorStyle.PHOENIX, 561 | sep="", 562 | ) 563 | ) 564 | 565 | # ChatGPT default template 566 | register_conv_template( 567 | Conversation( 568 | name="chatgpt", 569 | system="You are a helpful assistant.", 570 | roles=("user", "assistant"), 571 | messages=(), 572 | offset=0, 573 | sep_style=None, 574 | sep=None, 575 | ) 576 | ) 577 | 578 | # Claude default template 579 | register_conv_template( 580 | Conversation( 581 | name="claude", 582 | system="", 583 | roles=("Human", "Assistant"), 584 | messages=(), 585 | offset=0, 586 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 587 | sep="\n\n", 588 | ) 589 | ) 590 | 591 | # MPT default template 592 | register_conv_template( 593 | Conversation( 594 | name="mpt-7b-chat", 595 | system="""<|im_start|>system 596 | - You are a helpful assistant chatbot trained by MosaicML. 597 | - You answer questions. 598 | - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 
599 | - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", 600 | roles=("<|im_start|>user", "<|im_start|>assistant"), 601 | messages=(), 602 | offset=0, 603 | sep_style=SeparatorStyle.CHATML, 604 | sep="<|im_end|>", 605 | stop_token_ids=[50278, 0], 606 | ) 607 | ) 608 | 609 | # MPT-30b-chat default template 610 | register_conv_template( 611 | Conversation( 612 | name="mpt-30b-chat", 613 | system="""<|im_start|>system 614 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 615 | roles=("<|im_start|>user", "<|im_start|>assistant"), 616 | messages=(), 617 | offset=0, 618 | sep_style=SeparatorStyle.CHATML, 619 | sep="<|im_end|>", 620 | stop_token_ids=[50278, 0], 621 | ) 622 | ) 623 | 624 | # MPT-30b-instruct default template 625 | # reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting 626 | register_conv_template( 627 | Conversation( 628 | name="mpt-30b-instruct", 629 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.", 630 | roles=("### Instruction", "### Response"), 631 | messages=(), 632 | offset=0, 633 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, 634 | sep="\n\n", 635 | stop_token_ids=[50278, 0], 636 | ) 637 | ) 638 | 639 | # Bard default template 640 | # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 641 | # https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 642 | register_conv_template( 643 | Conversation( 644 | name="bard", 645 | system="", 646 | roles=("0", "1"), 647 | messages=(), 648 | offset=0, 649 | sep_style=None, 650 | sep=None, 651 | ) 652 | ) 653 | 654 | # BiLLa default template 655 | register_conv_template( 656 | Conversation( 657 | name="billa", 658 | system="", 659 | roles=("Human", "Assistant"), 660 | messages=(), 661 | offset=0, 662 | sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, 663 | sep="\n", 664 | stop_str="Human:", 665 | ) 666 | ) 667 | 668 | # RedPajama INCITE default template 669 | register_conv_template( 670 | Conversation( 671 | name="redpajama-incite", 672 | system="", 673 | roles=("", ""), 674 | messages=(), 675 | offset=0, 676 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 677 | sep="\n", 678 | stop_str="", 679 | ) 680 | ) 681 | 682 | # h2oGPT default template 683 | register_conv_template( 684 | Conversation( 685 | name="h2ogpt", 686 | system="", 687 | roles=("<|prompt|>", "<|answer|>"), 688 | messages=(), 689 | offset=0, 690 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 691 | sep="", 692 | ) 693 | ) 694 | 695 | # Robin default template 696 | register_conv_template( 697 | Conversation( 698 | name="Robin", 699 | system="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", 700 | roles=("###Human", "###Assistant"), 701 | messages=(), 702 | offset=0, 703 | sep_style=SeparatorStyle.ROBIN, 704 | sep="\n", 705 | stop_token_ids=[2, 396], 706 | stop_str="###", 707 | ) 708 | ) 709 | 710 | # Snoozy default template 711 | # Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232 712 | register_conv_template( 713 | Conversation( 714 | name="snoozy", 715 | system="### Instruction:\nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.", 716 | roles=("### Prompt", "### Response"), 717 | messages=(), 718 | offset=0, 719 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 720 | sep="\n", 721 | stop_str="###", 722 | ) 723 | ) 724 | 725 | # manticore default template 726 | register_conv_template( 727 | Conversation( 728 | name="manticore", 729 | system="", 730 | roles=("USER", "ASSISTANT"), 731 | messages=(), 732 | offset=0, 733 | sep_style=SeparatorStyle.ADD_COLON_TWO, 734 | sep="\n", 735 | sep2="", 736 | ) 737 | ) 738 | 739 | # Falcon default template 740 | register_conv_template( 741 | Conversation( 742 | name="falcon", 743 | system="", 744 | roles=("User", "Assistant"), 745 | messages=[], 746 | offset=0, 747 | sep_style=SeparatorStyle.RWKV, 748 | sep="\n", 749 | sep2="<|endoftext|>", 750 | stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text 751 | stop_token_ids=[ 752 | 0, 753 | 1, 754 | 2, 755 | 3, 756 | 4, 757 | 5, 758 | 6, 759 | 7, 760 | 8, 761 | 9, 762 | 10, 763 | 11, 764 | ], # it better only put special tokens here, because tokenizer only remove special tokens 765 | ) 766 | ) 767 | 768 | # ChagGPT default template 769 | register_conv_template( 770 | Conversation( 771 | name="polyglot_changgpt", 772 | system="", 773 | roles=("B", "A"), 774 | messages=(), 775 | offset=0, 776 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 777 | sep="\n", 778 | ) 779 | ) 780 | 781 | # tigerbot template 782 | register_conv_template( 783 | Conversation( 784 | name="tigerbot", 785 | system="A chat between a curious user and an artificial intelligence assistant. " 786 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 787 | roles=("### Instruction", "### Response"), 788 | messages=(), 789 | offset=0, 790 | sep_style=SeparatorStyle.ROBIN, 791 | sep="\n\n", 792 | stop_str="###", 793 | ) 794 | ) 795 | 796 | # ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst 797 | register_conv_template( 798 | Conversation( 799 | name="xgen", 800 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", 801 | roles=("### Human: ", "###"), 802 | messages=(), 803 | offset=0, 804 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 805 | sep="\n", 806 | stop_token_ids=[50256, 0, 1, 2], 807 | stop_str="<|endoftext|>", 808 | ) 809 | ) 810 | 811 | # Internlm-chat template 812 | register_conv_template( 813 | Conversation( 814 | name="internlm-chat", 815 | system="A chat between a curious <|User|> and an <|Bot|>. 
The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", 816 | roles=("<|User|>", "<|Bot|>"), 817 | messages=(), 818 | offset=0, 819 | sep_style=SeparatorStyle.CHATINTERN, 820 | sep="", 821 | sep2="", 822 | stop_token_ids=[1, 103028], 823 | stop_str="<|User|>", 824 | ) 825 | ) 826 | 827 | # StarChat template 828 | register_conv_template( 829 | Conversation( 830 | name="starchat", 831 | system="\n", 832 | roles=("<|user|>", "<|assistant|>"), 833 | messages=(), 834 | offset=0, 835 | sep_style=SeparatorStyle.CHATML, 836 | sep="<|end|>", 837 | stop_token_ids=[0, 49155], 838 | stop_str="<|end|>", 839 | ) 840 | ) 841 | 842 | # Baichuan-13B-Chat template 843 | register_conv_template( 844 | # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507 845 | # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json 846 | Conversation( 847 | name="baichuan-chat", 848 | system="", 849 | roles=(" ", " "), 850 | messages=(), 851 | offset=0, 852 | sep_style=SeparatorStyle.NO_COLON_TWO, 853 | sep="", 854 | sep2="", 855 | stop_token_ids=[2, 195], 856 | ) 857 | ) 858 | 859 | # llama2 template 860 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 861 | register_conv_template( 862 | Conversation( 863 | name="llama-2", 864 | system="[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " 865 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " 866 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 867 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
" 868 | "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n", 869 | roles=("[INST]", "[/INST]"), 870 | messages=(), 871 | offset=0, 872 | sep_style=SeparatorStyle.LLAMA2, 873 | sep=" ", 874 | sep2=" ", 875 | stop_token_ids=[2], 876 | ) 877 | ) 878 | 879 | if __name__ == "__main__": 880 | conv = get_conv_template("vicuna_v1.1") 881 | conv.append_message(conv.roles[0], "Hello!") 882 | conv.append_message(conv.roles[1], "Hi!") 883 | conv.append_message(conv.roles[0], "How are you?") 884 | conv.append_message(conv.roles[1], None) 885 | print(conv.get_prompt()) 886 | -------------------------------------------------------------------------------- /src/deita/selection/embedder/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from datasets import Dataset 6 | from typing import Sequence, Dict 7 | from dataclasses import dataclass 8 | from deita.selection.embedder.conversation import get_conv_template 9 | 10 | IGNORE_INDEX=-100 11 | 12 | 13 | def preprocess( 14 | samples: Dataset, 15 | conv_template, 16 | only_answer, 17 | max_length, 18 | tokenizer 19 | ) -> Dict: 20 | 21 | conv = get_conv_template(conv_template) 22 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 23 | 24 | sources = samples["conversations"] 25 | sample_ids = samples["input_idx"] 26 | 27 | # Apply prompt templates 28 | conversations = [] 29 | 30 | if not only_answer: 31 | for i, source in enumerate(sources): 32 | if roles[source[0]["from"]] != conv.roles[0]: 33 | # Skip the first one if it is not from human 34 | source = source[1:] 35 | 36 | conv.messages = [] 37 | for j, sentence in enumerate(source): 38 | role = roles[sentence["from"]] 39 | # assert role == conv.roles[j % 2], f"{i}" 40 | assert role == conv.roles[j % 2], breakpoint() 41 | conv.append_message(role, sentence["value"]) 42 | conversations.append(conv.get_prompt()) 43 | else: 44 | for i, source in enumerate(sources): 45 | if roles[source[0]["from"]] != conv.roles[0]: 46 | # Skip the first one if it is not from human 47 | source = source[1:] 48 | 49 | messages = [] 50 | for j, sentence in enumerate(source): 51 | if j % 2 == 0: 52 | continue 53 | messages.append(sentence["value"]) 54 | conversations.append("\n".join(messages)) 55 | 56 | input_ids = tokenizer( 57 | conversations, 58 | return_tensors="pt", 59 | padding="longest", 60 | max_length=max_length, 61 | truncation=True, 62 | ).input_ids 63 | 64 | return dict( 65 | input_ids=input_ids, 66 | idx = sample_ids, 67 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 68 | ) 69 | 70 | @dataclass 71 | class DataCollatorForSupervisedDataset(object): 72 | """Collate examples for supervised fine-tuning.""" 73 | 74 | def __init__(self, tokenizer): 75 | self.tokenizer = tokenizer 76 | 77 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 78 | 79 | input_ids, = tuple([instance[key] for instance in instances] for key in ("input_ids",)) 80 | input_ids = torch.tensor(input_ids) 81 | 82 | input_ids = torch.nn.utils.rnn.pad_sequence( 83 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 84 | ) 85 | instance_index = torch.tensor([instance["idx"] for instance in instances]).to(input_ids.device) 86 | 87 | return dict( 88 | idx = instance_index, 89 | input_ids=input_ids, 90 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 91 | ) 92 | 93 | def set_random_seed(seed: int): 94 | """ 95 | Set the random seed for
`random`, `numpy`, `torch`, `torch.cuda`. 96 | 97 | Parameters 98 | ------------ 99 | seed : int 100 | The default seed. 101 | 102 | """ 103 | random.seed(seed) 104 | np.random.seed(seed) 105 | torch.manual_seed(seed) 106 | if torch.cuda.is_available(): 107 | torch.cuda.manual_seed_all(seed) 108 | 109 | 110 | def batchlize(examples: list, batch_size: int, random_shuffle: bool, sort: bool, length_field = 'specific_length'): 111 | """ 112 | Convert examples to a dataloader. 113 | 114 | Parameters 115 | ------------ 116 | examples : list. 117 | Data list. 118 | batch_size : int. 119 | 120 | random_shuffle : bool 121 | If true, the dataloader shuffle the training data. 122 | 123 | sort: bool 124 | If true, data will be sort by its input length 125 | Returns 126 | ------------ 127 | dataloader: 128 | Dataloader with batch generator. 129 | """ 130 | size = 0 131 | dataloader = [] 132 | length = len(examples) 133 | if (random_shuffle): 134 | random.shuffle(examples) 135 | 136 | new_examples = examples 137 | if sort: 138 | new_examples = sorted(examples, key = lambda x: len(x[length_field])) 139 | 140 | while size < length: 141 | if length - size > batch_size: 142 | dataloader.append(new_examples[size : size+batch_size]) 143 | size += batch_size 144 | else: 145 | dataloader.append(new_examples[size : size+(length-size)]) 146 | size += (length - size) 147 | return dataloader 148 | 149 | def get_emb_name(**kwargs): 150 | 151 | if kwargs.get('model_path'): 152 | model_path = kwargs.pop("model_path") 153 | return os.path.basename(model_path) 154 | 155 | if kwargs.get('emb_name'): 156 | emb_name = kwargs.pop('emb_name') 157 | return os.path.basename(emb_name) -------------------------------------------------------------------------------- /src/deita/selection/filter/__init__.py: -------------------------------------------------------------------------------- 1 | from .combined_filter import Combined_Filter 2 | 3 | __all__ = ["Combined_Filter"] -------------------------------------------------------------------------------- /src/deita/selection/filter/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class IterativeFilter(object): 9 | def __init__(self, **kwargs): 10 | 11 | self.threshold = kwargs.get('threshold') 12 | self.data_size = kwargs.get('data_size') 13 | self.sort_key = kwargs.get('sort_key') 14 | self.chunk_size = kwargs.get('chunk_size') 15 | self.batch_size = kwargs.get('batch_size', 1) 16 | self.normalize_emb = kwargs.get('normalize_emb', True) 17 | self.distance_metric = kwargs.get('distance_metric', 'cosine') 18 | self.embedding_field = kwargs.get('embedding_field', "embedding") 19 | 20 | self.device = kwargs.get('device') if kwargs.get('device') == 'cpu' else f"cuda:{kwargs.get('device')}" 21 | 22 | def compute_distance(self, matrix, matrix_2): 23 | 24 | """ 25 | Compute cosine distance using pytorch 26 | """ 27 | 28 | if self.normalize_emb: 29 | matrix = matrix / matrix.norm(dim=1)[:, None] 30 | matrix_2 = matrix_2 / matrix_2.norm(dim=1)[:, None] 31 | 32 | if self.distance_metric == 'cosine': 33 | matrix_norm = matrix / matrix.norm(dim=1)[:, None] 34 | matrix_2_norm = matrix_2 / matrix_2.norm(dim=1)[:, None] 35 | return torch.mm(matrix_norm, matrix_2_norm.t()) 36 | elif self.distance_metric == 'manhattan': 37 | return torch.cdist(matrix[None], matrix_2[None], p = 1).squeeze(0) 38 | else: 39 | raise 
-------------------------------------------------------------------------------- /src/deita/selection/filter/__init__.py: -------------------------------------------------------------------------------- 1 | from .combined_filter import Combined_Filter 2 | 3 | __all__ = ["Combined_Filter"] -------------------------------------------------------------------------------- /src/deita/selection/filter/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class IterativeFilter(object): 9 | def __init__(self, **kwargs): 10 | 11 | self.threshold = kwargs.get('threshold') 12 | self.data_size = kwargs.get('data_size') 13 | self.sort_key = kwargs.get('sort_key') 14 | self.chunk_size = kwargs.get('chunk_size') 15 | self.batch_size = kwargs.get('batch_size', 1) 16 | self.normalize_emb = kwargs.get('normalize_emb', True) 17 | self.distance_metric = kwargs.get('distance_metric', 'cosine') 18 | self.embedding_field = kwargs.get('embedding_field', "embedding") 19 | 20 | self.device = kwargs.get('device') if kwargs.get('device') == 'cpu' else f"cuda:{kwargs.get('device')}" 21 | 22 | def compute_distance(self, matrix, matrix_2): 23 | 24 | """ 25 | Compute pairwise cosine similarity (or Manhattan distance) between two embedding matrices using PyTorch 26 | """ 27 | 28 | if self.normalize_emb: 29 | matrix = matrix / matrix.norm(dim=1)[:, None] 30 | matrix_2 = matrix_2 / matrix_2.norm(dim=1)[:, None] 31 | 32 | if self.distance_metric == 'cosine': 33 | matrix_norm = matrix / matrix.norm(dim=1)[:, None] 34 | matrix_2_norm = matrix_2 / matrix_2.norm(dim=1)[:, None] 35 | return torch.mm(matrix_norm, matrix_2_norm.t()) 36 | elif self.distance_metric == 'manhattan': 37 | return torch.cdist(matrix[None], matrix_2[None], p = 1).squeeze(0) 38 | else: 39 | raise ValueError(f"Unsupported distance metric '{self.distance_metric}'; only 'cosine' and 'manhattan' are supported") 40 | 41 | def _sort(self, df): 42 | 43 | raise NotImplementedError 44 | 45 | def distance_chunk_by_chunk(self, existing_emb, cur_emb): 46 | 47 | distance_placeholder = torch.zeros((cur_emb.size(0), existing_emb.shape[0]), dtype = torch.float32).to(self.device) 48 | 49 | for i in range(0, existing_emb.shape[0], self.chunk_size): 50 | 51 | chunk_embeddings = existing_emb[i: i + self.chunk_size] 52 | chunk_embeddings = torch.tensor(chunk_embeddings, dtype = torch.float32).to(self.device) 53 | 54 | if chunk_embeddings.ndim == 4: 55 | chunk_embeddings = chunk_embeddings.squeeze(1).squeeze(1) 56 | 57 | distance_matrix = self.compute_distance(cur_emb, chunk_embeddings) 58 | actual_chunk = distance_matrix.size(1) 59 | 60 | distance_placeholder[:, i: i + actual_chunk] = distance_matrix 61 | 62 | return distance_placeholder 63 | 64 | def filter(self, df): 65 | 66 | logger.info(f"Data number before filtering: #{len(df)}") 67 | 68 | df_sorted = self._sort(df) 69 | 70 | embeddings = df_sorted[self.embedding_field] 71 | embeddings = np.array(embeddings.values.tolist()) 72 | 73 | filtered_indices = [0] 74 | 75 | start_cnt = 0 76 | for i in tqdm(range(1, embeddings.shape[0], self.batch_size), total = embeddings.shape[0] // self.batch_size): 77 | 78 | cur_emb = torch.tensor(embeddings[i:i+self.batch_size], dtype = torch.float32).to(self.device) 79 | 80 | if cur_emb.ndim == 4: 81 | cur_emb = cur_emb.squeeze(1).squeeze(1) 82 | 83 | if cur_emb.ndim == 1: 84 | cur_emb = cur_emb.unsqueeze(0) 85 | 86 | batch_idx = torch.arange(i, i + cur_emb.size(0), dtype = torch.int64).to(self.device)  # torch.range is deprecated; arange takes an exclusive end 87 | 88 | existing_emb = embeddings[filtered_indices] 89 | 90 | if existing_emb.ndim == 1: 91 | existing_emb = existing_emb[None, :]  # existing_emb is a numpy array, so add the batch dim by indexing rather than torch.unsqueeze 92 | 93 | distance_existed = self.distance_chunk_by_chunk(existing_emb, cur_emb) 94 | distance_existed_bool = torch.any(distance_existed > self.threshold, dim = 1) 95 | 96 | distance_cur = self.distance_chunk_by_chunk(cur_emb, cur_emb) 97 | distance_cur = distance_cur.tril(-1) 98 | 99 | distance_cur_bool = torch.any(distance_cur > self.threshold, dim = 1) 100 | 101 | distance_bool = distance_existed_bool | distance_cur_bool 102 | 103 | filtered_indices.extend(batch_idx[~distance_bool].tolist()) 104 | 105 | if len(filtered_indices) - start_cnt > 1000: 106 | logger.info("Now data number: #{}".format(len(filtered_indices))) 107 | start_cnt = len(filtered_indices) 108 | 109 | if self.data_size > -1: 110 | if len(filtered_indices) >= self.data_size: 111 | break 112 | 113 | df_filtered = df_sorted.iloc[filtered_indices] 114 | logger.info(f"Data number after filtering: #{len(df_filtered)}") 115 | 116 | if self.data_size > -1: 117 | return df_filtered[:self.data_size] 118 | else: 119 | return df_filtered 120 | 121 |
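Editor's note: IterativeFilter above implements the score-first, diversity-aware selection loop: samples are visited in the order produced by _sort, and a candidate is kept only if its similarity to every already-kept sample (and to earlier candidates in the same batch, via the lower-triangular mask) stays at or below threshold, until data_size samples have been collected. The Combined_Filter subclass listed just below supplies _sort by multiplying the configured score columns element-wise and summing over turns; for example, per-turn complexity scores [4.2, 3.1] and quality scores [5.0, 2.5] combine to 4.2*5.0 + 3.1*2.5 = 28.75. The sketch below is a hedged illustration rather than repository code: it defines a toy subclass with a trivial _sort and runs it on random embeddings, and every column name and hyper-parameter value is made up.

import numpy as np
import pandas as pd
from deita.selection.filter.base import IterativeFilter

class ScoreSortedFilter(IterativeFilter):
    """Toy subclass for illustration: rank rows by a single 'score' column."""
    def _sort(self, df):
        return df.sort_values(by="score", ascending=False)

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "score": rng.random(8),
    "embedding": [rng.normal(size=16).tolist() for _ in range(8)],
})

flt = ScoreSortedFilter(
    threshold=0.9,   # a candidate is dropped if cosine similarity to any kept sample exceeds this
    data_size=5,     # stop once 5 samples are kept (-1 keeps filtering to the end)
    chunk_size=4,
    batch_size=2,
    device="cpu",
)
kept = flt.filter(df)
print(len(kept), kept["score"].tolist())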
-------------------------------------------------------------------------------- /src/deita/selection/filter/combined_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numpy as np 4 | from deita.pipeline.utils import sort_key_split 5 | from deita.selection.filter.base import IterativeFilter 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class Combined_Filter(IterativeFilter): 10 | 11 | def __init__(self, **kwargs): 12 | 13 | super().__init__(**kwargs) 14 | 15 | def _sort(self, df): 16 | 17 | """ 18 | Sort the dataframe by the combined score (element-wise product of the sort-key score columns, summed over turns) 19 | """ 20 | 21 | if isinstance(self.sort_key, list): 22 | all_sort_keys = self.sort_key 23 | else: 24 | all_sort_keys = sort_key_split(self.sort_key) 25 | 26 | logger.info("Compute final score for each sample, consider {}".format("+".join(all_sort_keys))) 27 | for sk in all_sort_keys: 28 | df[sk] = df[sk].apply(np.array) 29 | 30 | df["final_score"] = df[all_sort_keys[0]] 31 | for i in range(1, len(all_sort_keys)): 32 | df["final_score"] = df["final_score"] * df[all_sort_keys[i]] 33 | 34 | df["final_score"] = df["final_score"].apply(lambda x: x.sum()) 35 | df_sorted = df.sort_values(by = "final_score", ascending = False) 36 | 37 | return df_sorted 38 | 39 | -------------------------------------------------------------------------------- /src/deita/selection/filter/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/deita/b279f2c329b403d2612a61e270c8d2a2eeaed6f4/src/deita/selection/filter/utils.py -------------------------------------------------------------------------------- /src/deita/selection/scorer/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama_scorer import Llama_Scorer 2 | from .mistral_scorer import Mistral_Scorer 3 | 4 | __all__ = ["Llama_Scorer", "Mistral_Scorer"] -------------------------------------------------------------------------------- /src/deita/selection/scorer/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import softmax 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | import logging 5 | import torch 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class Scorer(object): 10 | def __init__(self, model_name_or_path: str, is_vllm: bool = False, **kwargs): 11 | self.is_vllm = is_vllm 12 | 13 | # Automatically detect the device (GPU or CPU) 14 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | logger.info(f"Using device: {self.device}") 16 | 17 | if not is_vllm: 18 | # Load tokenizer and model, and move the model to the specified device 19 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 20 | self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device) 21 | else: 22 | from vllm import LLM, SamplingParams 23 | self.llm = LLM(model_name_or_path) 24 | self.sampling_params = SamplingParams(max_tokens=2, logprobs=1000) 25 | 26 | def infer_score(self, user_input: str): 27 | max_new_tokens = 2 28 | 29 | if self.is_vllm: 30 | outputs = self.llm.generate(user_input, self.sampling_params) 31 | score_template = np.array([1, 2, 3, 4, 5, 6]) 32 | 33 | try: 34 | logprobs_list = outputs[0].outputs[0].logprobs[0] 35 | except IndexError: 36 | return 3.0 37 | else: 38 | # Encode the input as a tensor and move it to the device 39 | input_ids = self.tokenizer.encode(user_input, return_tensors="pt").to(self.device) 40 | outputs = self.model.generate( 41 | input_ids, 42 | max_new_tokens=max_new_tokens, 43 | num_return_sequences=1, 44 | return_dict_in_generate=True, 45 | output_scores=True, 46 | ) 47 | 48 | try: 49 | # Move logits to CPU and convert them to a NumPy array 50 | logprobs_list = outputs.scores[0][0].detach().cpu().numpy() 51 | except IndexError: 52 | return 3.0 53 | 54 | score_logits = [] 55 | score_template = np.array([1, 2, 3, 4, 5, 6]) 56 | for k in self.id2score: 57 | try: 58 | score_logits.append(logprobs_list[k]) 59 | except KeyError: 60 | return 3.0 61 | 62 | score_logits = np.array(score_logits) 63 | score_npy = softmax(score_logits, axis=0) 64 |
score_npy = score_npy * score_template 65 | 66 | score_npy = np.sum(score_npy, axis=0) 67 | 68 | return score_npy 69 | 70 | def infer_complexity(self, input_text: str): 71 | complexity_template = self.complexity_template 72 | user_input = complexity_template.format(instruction=input_text) 73 | 74 | return self.infer_score(user_input) 75 | 76 | def infer_quality(self, input_text: str, resp_text: str): 77 | quality_template = self.quality_template 78 | user_input = quality_template.format(instruction=input_text, output=resp_text) 79 | 80 | return self.infer_score(user_input) 81 | 82 | @property 83 | def id2score(self): 84 | raise NotImplementedError 85 | 86 | @property 87 | def complexity_template(self): 88 | raise NotImplementedError 89 | 90 | @property 91 | def quality_template(self): 92 | raise NotImplementedError 93 | -------------------------------------------------------------------------------- /src/deita/selection/scorer/llama_scorer.py: -------------------------------------------------------------------------------- 1 | 2 | from deita.selection.scorer.base import Scorer 3 | 4 | class Llama_Scorer(Scorer): 5 | 6 | @property 7 | def id2score(self): 8 | 9 | id2score = { 10 | 29896: "1", 11 | 29906: "2", 12 | 29941: "3", 13 | 29946: "4", 14 | 29945: "5", 15 | 29953: "6" 16 | } 17 | 18 | return id2score 19 | 20 | @property 21 | def complexity_template(self): 22 | 23 | complexity_template = ("You are a helpful assistant. Please identify the complexity score of the following user query. \n##Query: {instruction} \n##Complexity: ") 24 | 25 | return complexity_template 26 | 27 | @property 28 | def quality_template(self): 29 | 30 | quality_template = ("You are a helpful assistant. Please identify the quality score of the Response corresponding to the Question. \n #Question#:\n{instruction}\n#Response#:\n{output} \n##Quality: ") 31 | 32 | return quality_template -------------------------------------------------------------------------------- /src/deita/selection/scorer/mistral_scorer.py: -------------------------------------------------------------------------------- 1 | 2 | from deita.selection.scorer.base import Scorer 3 | 4 | class Mistral_Scorer(Scorer): 5 | 6 | @property 7 | def id2score(self): 8 | 9 | id2score = { 10 | 28740: "1", 11 | 28750: "2", 12 | 28770: "3", 13 | 28781: "4", 14 | 28782: "5", 15 | 28784: "6" 16 | } 17 | 18 | return id2score 19 | 20 | @property 21 | def complexity_template(self): 22 | 23 | complexity_template = ("You are a helpful assistant. Please identify the complexity score of the following user query. \n##Query: {instruction} \n##Complexity: ") 24 | 25 | return complexity_template 26 | 27 | @property 28 | def quality_template(self): 29 | 30 | quality_template = ("You are a helpful assistant. Please identify the quality score of the Response corresponding to the Question. \n #Question#:\n{instruction}\n#Response#:\n{output} \n##Quality: ") 31 | 32 | return quality_template --------------------------------------------------------------------------------
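Editor's note: the scorer classes above share one mechanism. The model is prompted with the complexity or quality template, only the distribution of the first generated token is inspected, and the (log-)probabilities of the tokens "1" through "6" listed in id2score are softmax-normalised and combined with the score template into an expected score; for instance, 30% of the mass on "4" and 70% on "5" yields 0.3*4 + 0.7*5 = 4.7, and any lookup failure falls back to a neutral 3.0. A hedged usage sketch follows; the Hugging Face checkpoint names are assumptions, so substitute whichever complexity and quality scorer models you actually use.

from deita.selection.scorer import Llama_Scorer

# Assumed checkpoint names, shown only for illustration.
complexity_scorer = Llama_Scorer("hkust-nlp/deita-complexity-scorer", is_vllm=False)
quality_scorer = Llama_Scorer("hkust-nlp/deita-quality-scorer", is_vllm=False)

instruction = "Write a Python function that merges two sorted lists."
response = "def merge(a, b):\n    return sorted(a + b)"

print("complexity:", complexity_scorer.infer_complexity(instruction))    # expected value in [1, 6]
print("quality:", quality_scorer.infer_quality(instruction, response))   # expected value in [1, 6]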