├── LICENSE ├── assets ├── cover.png └── intro.png ├── data ├── alpaca_reproduced │ └── data_52k.json ├── data_construction │ ├── examples_ctx.json │ └── examples_no_ctx.json └── testset │ ├── bpo_test.json │ ├── dolly_eval.json │ ├── self_instruct_eval.json │ └── vicuna_eval.jsonl ├── readme.md ├── requirements.txt └── src ├── data_construction ├── chatgpt_infer.py ├── process_en.py └── process_optimized_prompts.py ├── evaluation ├── cal_claude_score.py ├── cal_gpt4_score.py ├── claude_score.py ├── gpt4_score.py └── ranking_prompt.txt ├── infer_example.py ├── inference └── llama2_infer.py └── training ├── config ├── __init__.py ├── constant_map.py ├── deepspeed.json ├── deepspeed_offload.json ├── main.py ├── sft_config.py ├── sft_config_lora.py └── sft_config_ptv2.py ├── data └── finetune_train_examples.jsonl ├── data_processer.py ├── data_utils.py ├── infer.py ├── infer_finetuning.py ├── infer_lora_finetuning.py ├── infer_muti_lora_finetuning.py ├── infer_ptuning.py ├── module_setup.py ├── train.py └── train_hf.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /assets/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/BPO/9bf541587d1456fda6ffa46001ab545e892f4a2c/assets/cover.png -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thu-coai/BPO/9bf541587d1456fda6ffa46001ab545e892f4a2c/assets/intro.png -------------------------------------------------------------------------------- /data/data_construction/examples_ctx.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "Write a story about a person who has to overcome adversity to succeed.", 4 | "context": "", 5 | "good_res": "Once upon a time, there was a young man named John. John was born into a poor family, and he had to struggle to make ends meet. Growing up, he had always dreamed of going to college and getting a degree, but his family could not afford to send him. Despite this, John did not give up on his dreams. Instead, he worked hard in school, studied relentlessly, and applied for scholarships.\n\nUnfortunately, John faced many barriers along the way. His high school was in a rough neighborhood and did not have the resources to provide the education he needed. He often struggled to focus on his studies because he was hungry or worried about his family’s finances. Still, John was determined to succeed, and he poured all of his energy into his schoolwork.\n\nFinally, after years of hard work, John was awarded a full scholarship to college. He was overjoyed – his dream was within reach. However, even in college, John still faced adversity. He discovered that he had a learning disability that made it difficult for him to keep up with his coursework. It was frustrating and discouraging, but John refused to let it hold him back.\n\nHe found a tutor and worked with them tirelessly, often late into the night. He also joined study groups and sought help from his professors. Slowly but surely, John began to overcome his disability. He gained confidence in himself and his abilities, and he started to excel in his classes.\n\nIn the end, John graduated at the top of his class, with a degree in engineering. He had overcome incredible adversity to achieve his dreams, and he was proud of all that he had accomplished. With his degree in hand, John landed a high-paying job and was able to provide for his family.\n\nJohn’s story is a testament to the strength of the human spirit. Against all odds, he overcame adversity and achieved his goals. He proved that with hard work, perseverance, and a relentless determination to succeed, anything is possible.", 6 | "bad_res": "The story is about a young woman who has been dealt a difficult hand in life. She has been let go from her job and is struggling to make ends meet. Despite all the hardship, she finds strength and resilience from within. With determination and hard work, she pursues her dreams and eventually finds success. Although life is not easy, she never gives up and perseveres to achieve her goals." 
7 | }, 8 | { 9 | "instruction": "Come up with a data story involving the following data set.", 10 | "context": "A data set on the average amount of time spent in meetings per day by people who work in the same industry", 11 | "good_res": "As an AI assistant, I have analyzed a dataset that provides information on the average amount of time spent in meetings per day by people who work in the same industry. This data set offers valuable insights into the meeting culture of different industries and is perfect for companies looking to improve their meeting strategy.\n\nThe data reveals that the average time spent in meetings varied significantly by industry. For instance, on average, employees in the technology industry spent the highest amount of time in meetings, with an average of nearly two hours per day, followed by those in the finance industry at an average of 1.5 hours per day. On the other hand, employees in the retail industry had the lowest average meeting time, clocking at less than half an hour per day.\n\nFurther analysis of the data indicates that there is a positive correlation between the average meeting time per day and the level of collaboration required in the industry. Industries that require a high level of collaboration, such as tech and finance, also have the highest amount of time spent in meetings. Conversely, industries with a lower level of collaboration, such as retail, have a shorter average meeting time.\n\nWe may also infer that the style of meetings and their effectiveness also vary by industry. For example, shorter, more focused meetings in the retail industry may lead to increased productivity, whereas longer meetings in the technology industry may allow for more innovation and brainstorming.\n\nIn conclusion, this data story suggests that the amount of time spent in meetings is influenced by the nature of the industry, level of collaboration and meeting style. Companies looking to optimize their meeting strategy can leverage this data to make informed decisions about the duration, frequency and format of their meetings.", 12 | "bad_res": "The data reveals that people working in the same industry spend drastically different amounts of time in meetings on average. For example, while some industries require as much as 8.5 hours a day spent in short, frequent meetings, others require as little as 2.5 hours. This suggests that some industries are much more highly dependent on collaborative work, while others may focus more on solo tasks. Knowing this can help shape decisions about the organizational culture and communication between team members in different industries. Furthermore, the data can provide useful information to individuals looking to advance in their respective industries." 13 | } 14 | ] -------------------------------------------------------------------------------- /data/data_construction/examples_no_ctx.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "Name a famous scientist who won a Nobel Prize.", 4 | "good_res": "One famous scientist who won a Nobel Prize is Marie Curie. She was awarded the Nobel Prize in Physics in 1903 for her work on radioactivity, and then again in 1911 for her discovery of the elements radium and polonium, making her the first person to win the Nobel Prize in two different scientific fields.", 5 | "bad_res": "Albert Einstein is a famous scientist who won the Nobel Prize in Physics in 1921." 
6 | }, 7 | { 8 | "instruction": "Generate a business idea focused on using virtual reality.", 9 | "good_res": "One innovative business idea that involves using virtual reality is creating immersive virtual experiences for home shoppers. This would involve setting up a service that can provide potential homebuyers with a realistic virtual tour of properties they are interested in, all from the comfort of their own homes. This would save time and money for both the buyer and the seller, as they wouldn't need to schedule multiple in-person showings.\n\nClients would put on a virtual reality headset and be transported into a fully immersive, life-like representation of the property they're interested in. This virtual tour would allow homebuyers to walk through the entire property, explore different rooms, and see features up close. This level of immersion would help buyers make informed decisions when it comes to purchasing a property.\n\nTo make this business unique, the virtual tours could be personalized and interactive. Customers could choose different times of day and weather conditions to see how the property looks, add and remove furniture, and make other changes to get a better feel for the space.\n\nThis business idea combines the power of virtual reality with the convenience of shopping from home. It has the potential to revolutionize the homebuying process, making it easier, more efficient, and more enjoyable for everyone involved.", 10 | "bad_res": "A business idea focused on using virtual reality could be to create a virtual reality simulator for businesses. This simulator could allow companies to embed various scenarios into their employees' training and to help them learn and develop new skills. Companies could also use the simulator to test out new strategies or products in a virtual environment." 
11 | } 12 | ] -------------------------------------------------------------------------------- /data/testset/vicuna_eval.jsonl: -------------------------------------------------------------------------------- 1 | {"question_id": 1, "text": "How can I improve my time management skills?", "category": "generic"} 2 | {"question_id": 2, "text": "What are the most effective ways to deal with stress?", "category": "generic"} 3 | {"question_id": 3, "text": "What are the main differences between Python and JavaScript programming languages?", "category": "generic"} 4 | {"question_id": 4, "text": "How can I increase my productivity while working from home?", "category": "generic"} 5 | {"question_id": 5, "text": "Can you explain the basics of quantum computing?", "category": "generic"} 6 | {"question_id": 6, "text": "What are the differences between plant-based and animal-based protein sources?", "category": "generic"} 7 | {"question_id": 7, "text": "How can I develop my critical thinking skills?", "category": "generic"} 8 | {"question_id": 8, "text": "What are the major challenges faced by the education sector today?", "category": "generic"} 9 | {"question_id": 9, "text": "What are the primary factors that influence consumer behavior?", "category": "generic"} 10 | {"question_id": 10, "text": "What are the most effective strategies for conflict resolution in the workplace?", "category": "generic"} 11 | {"question_id": 11, "text": "What are some potential implications of using a single-use plastic bottle versus a reusable bottle on both the environment and human health?", "category": "knowledge"} 12 | {"question_id": 12, "text": "What factors would you consider when designing an inclusive and accessible public transportation system?", "category": "knowledge"} 13 | {"question_id": 13, "text": "How can governments utilize fiscal and monetary policies to combat economic recessions?", "category": "knowledge"} 14 | {"question_id": 14, "text": "How do language and cultural barriers affect the way people communicate and form relationships in multicultural societies?", "category": "knowledge"} 15 | {"question_id": 15, "text": "Describe a scenario where artificial intelligence could be used to improve the quality and efficiency of healthcare delivery.", "category": "knowledge"} 16 | {"question_id": 16, "text": "Explain the process of gene editing using CRISPR-Cas9 technology, and discuss its potential applications and ethical implications.", "category": "knowledge"} 17 | {"question_id": 17, "text": "How do vaccinations work to protect individuals and communities from infectious diseases, and what is herd immunity?", "category": "knowledge"} 18 | {"question_id": 18, "text": "How do social media platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?", "category": "knowledge"} 19 | {"question_id": 19, "text": "How do cultural, social, and economic factors influence people's food choices, and how can this knowledge be used to promote healthier diets?", "category": "knowledge"} 20 | {"question_id": 20, "text": "Explain the process of natural selection and how it contributes to the evolution and adaptation of species.", "category": "knowledge"} 21 | {"question_id": 21, "text": "How would you introduce yourself as a medieval knight at a royal banquet?", "category": "roleplay"} 22 | {"question_id": 22, "text": "As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?", "category": "roleplay"} 23 | 
{"question_id": 23, "text": "If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?", "category": "roleplay"} 24 | {"question_id": 24, "text": "As a superhero, how would you explain your origin story to a curious child?", "category": "roleplay"} 25 | {"question_id": 25, "text": "Imagine you are a time traveler from the year 3000. What technological advancements would you tell people about?", "category": "roleplay"} 26 | {"question_id": 26, "text": "As a sports commentator, describe the winning play in the final seconds of a championship game.", "category": "roleplay"} 27 | {"question_id": 27, "text": "Pretend to be a world-famous chef. How would you describe your signature dish to a panel of judges?", "category": "roleplay"} 28 | {"question_id": 28, "text": "You are a mountain climber reaching the summit of Mount Everest. Describe your emotions and the view from the top.", "category": "roleplay"} 29 | {"question_id": 29, "text": "As a space colonist on Mars, describe your daily life and the challenges you face living on another planet.", "category": "roleplay"} 30 | {"question_id": 30, "text": "Pretend to be a character in a post-apocalyptic world. Describe how you survive and the allies you encounter.", "category": "roleplay"} 31 | {"question_id": 31, "text": "How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?", "category": "common-sense"} 32 | {"question_id": 32, "text": "What are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed?", "category": "common-sense"} 33 | {"question_id": 33, "text": "Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?", "category": "common-sense"} 34 | {"question_id": 34, "text": "How can you determine if a person is genuinely interested in a conversation or simply being polite?", "category": "common-sense"} 35 | {"question_id": 35, "text": "Why might someone prefer to shop at a small, locally-owned business instead of a large chain store, even if the prices are higher?", "category": "common-sense"} 36 | {"question_id": 36, "text": "How can you assess the credibility of a source of information, such as a news article or blog post, without relying solely on the reputation of the author or publisher?", "category": "common-sense"} 37 | {"question_id": 37, "text": "Why do some people enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, while others avoid these experiences?", "category": "common-sense"} 38 | {"question_id": 38, "text": "How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?", "category": "common-sense"} 39 | {"question_id": 39, "text": "Do we have a moral obligation to explore space, or should we focus on solving Earth's problems first?", "category": "common-sense"} 40 | {"question_id": 40, "text": "In a world where automation is becoming increasingly prevalent, is it more important to prioritize job creation or technological progress?", "category": "common-sense"} 41 | {"question_id": 41, "text": "How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 42 | {"question_id": 42, "text": "How many atoms are in a grain of salt? 
Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 43 | {"question_id": 43, "text": "How many lightning strikes occur on Earth each day? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 44 | {"question_id": 44, "text": "How many balloons would it take to lift a house like in the movie \"Up\"? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 45 | {"question_id": 45, "text": "How many text messages are sent globally in a minute? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 46 | {"question_id": 46, "text": "How many words are spoken daily on Earth? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 47 | {"question_id": 47, "text": "How many snowflakes fall during a typical winter? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 48 | {"question_id": 48, "text": "How many pages are in all the books ever written? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 49 | {"question_id": 49, "text": "How many times has the Earth orbited the Sun since the beginning of life? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 50 | {"question_id": 50, "text": "How many songs have been recorded throughout history? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"} 51 | {"question_id": 51, "text": "What if the Internet had been invented during the Renaissance period?", "category": "counterfactual"} 52 | {"question_id": 52, "text": "What if the Aztecs had successfully repelled the Spanish conquistadors?", "category": "counterfactual"} 53 | {"question_id": 53, "text": "What if the Black Death had not occurred in the 14th century?", "category": "counterfactual"} 54 | {"question_id": 54, "text": "What if Isaac Newton had focused on biology instead of physics?", "category": "counterfactual"} 55 | {"question_id": 55, "text": "What if the Beatles had never formed as a band?", "category": "counterfactual"} 56 | {"question_id": 56, "text": "What if Alan Turing had not cracked the Enigma code during World War II?", "category": "counterfactual"} 57 | {"question_id": 57, "text": "What if the Suez Canal had never been constructed?", "category": "counterfactual"} 58 | {"question_id": 58, "text": "What if the Maya civilization had never mysteriously collapsed?", "category": "counterfactual"} 59 | {"question_id": 59, "text": "What if Christopher Columbus had not discovered the Americas?", "category": "counterfactual"} 60 | {"question_id": 60, "text": "What if Vincent van Gogh had been a successful artist during his lifetime?", "category": "counterfactual"} 61 | {"question_id": 61, "text": "Develop a C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file.", "category": "coding"} 62 | {"question_id": 62, "text": "Implement a Python function to find the longest common subsequence of two input strings using dynamic programming.", "category": "coding"} 63 | {"question_id": 
63, "text": "Implement a regular expression in Python to validate an email address.", "category": "coding"} 64 | {"question_id": 64, "text": "Write a program to find the nth Fibonacci number using dynamic programming.", "category": "coding"} 65 | {"question_id": 65, "text": "Implement a binary search algorithm to find a specific element in a sorted array.", "category": "coding"} 66 | {"question_id": 66, "text": "Implement a queue data structure using two stacks in Python.", "category": "coding"} 67 | {"question_id": 67, "text": "Implement a program to find the common elements in two arrays without using any extra data structures.", "category": "coding"} 68 | {"question_id": 68, "text": "Given that f(x) = 5x^3 - 2x + 3, find the value of f(2).", "category": "math"} 69 | {"question_id": 69, "text": "Solve for x in the equation 3x + 10 = 5(x - 2).", "category": "math"} 70 | {"question_id": 70, "text": "If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?", "category": "math"} 71 | {"question_id": 71, "text": "Can you help me write a formal email to a potential business partner proposing a joint venture?", "category": "writing"} 72 | {"question_id": 72, "text": "Can you help me write a resignation letter to my current employer, while leaving on good terms and expressing gratitude for the opportunities provided?", "category": "writing"} 73 | {"question_id": 73, "text": "Use an appropriate format to structure a formal letter of recommendation for a student applying to a prestigious graduate program in computer science.", "category": "writing"} 74 | {"question_id": 74, "text": "Write a compelling product launch announcement email to inform our customers of our new software solution.", "category": "writing"} 75 | {"question_id": 75, "text": "Draft an apology email to a customer who experienced a delay in their order, and provide reassurance that the issue has been resolved.", "category": "writing"} 76 | {"question_id": 76, "text": "Write a script for a YouTube video exploring the history and cultural significance of jazz.", "category": "writing"} 77 | {"question_id": 77, "text": "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "category": "writing"} 78 | {"question_id": 78, "text": "Write a captivating movie review for a recently released science fiction film, discussing its plot, characters, and special effects.", "category": "writing"} 79 | {"question_id": 79, "text": "Structure a podcast script for an episode discussing the influence of streaming platforms on the music industry.", "category": "writing"} 80 | {"question_id": 80, "text": "Write a symphony concert review, discussing the orchestra's performance and overall audience experience.", "category": "writing"} -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # Black-Box Prompt Optimization (BPO) 6 | ### Aligning Large Language Models without Model Training (ACL 2024) 7 | 8 |

9 | 🤗 [Model](https://huggingface.co/THUDM/BPO) • 📚 [Data](https://huggingface.co/datasets/THUDM/BPO) • 📃 [Paper](https://arxiv.org/abs/2311.04155) • 🌐 [Demo](https://huggingface.co/spaces/CCCCCC/BPO_demo) 10 |

11 | 12 | (Upper) Black-box Prompt Optimization (BPO) offers a conceptually new perspective to bridge the gap between humans and LLMs. (Lower) On Vicuna Eval’s pairwise evaluation, we show that BPO further aligns gpt-3.5-turbo and claude-2 without training. It also outperforms both PPO & DPO and presents orthogonal improvements. 13 | 14 |
15 | BPO 16 |
17 | 18 |
19 |
20 | 21 | ## Update 22 | We have released our [model](https://huggingface.co/THUDM/BPO) and [data](https://huggingface.co/datasets/THUDM/BPO) on Hugging Face. 23 | 24 | We have also built a [demo](https://huggingface.co/spaces/CCCCCC/BPO_demo) for BPO on Hugging Face. 25 |
26 | 27 | ## Table of Contents 28 | - [Model](#model) 29 | - [Data](#data) 30 | - [Quick Start](#quick-start) 31 | - [Data Construction](#data-construction) 32 | - [Model Training](#model-training) 33 | - [Inference](#inference) 34 | - [Evaluation](#evaluation) 35 | - [Citation](#citation) 36 | 37 | 38 | ## Model 39 | The prompt preference optimization model can be downloaded from [Hugging Face](https://huggingface.co/THUDM/BPO). 40 | 41 | Inference code (please refer to [src/infer_example.py](src/infer_example.py) for more instructions on how to optimize your prompts): 42 | ```python 43 | from transformers import AutoModelForCausalLM, AutoTokenizer 44 | 45 | model_path = 'THUDM/BPO' 46 | 47 | prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]" 48 | 49 | device = 'cuda:0' 50 | model = AutoModelForCausalLM.from_pretrained(model_path).half().eval().to(device) 51 | # for 8bit 52 | # model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, load_in_8bit=True) 53 | tokenizer = AutoTokenizer.from_pretrained(model_path) 54 | 55 | text = 'Tell me about Harry Potter' 56 | 57 | prompt = prompt_template.format(text) 58 | model_inputs = tokenizer(prompt, return_tensors="pt").to(device) 59 | output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.6, num_beams=1) 60 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip() 61 | 62 | print(resp) 63 | ``` 64 | 65 | ## Data 66 | 67 | ### BPO dataset 68 | The BPO dataset can be found on [Hugging Face](https://huggingface.co/datasets/THUDM/BPO). 69 | 70 | ### BPO for SFT Data Construction 71 | The `alpaca_reproduced` directory contains the BPO-reproduced Alpaca dataset. The data format is: 72 | ```json 73 | { 74 | "instruction": {instruction}, 75 | "input": {input}, 76 | "output": {output}, 77 | "optimized_prompt": {optimized_prompt}, 78 | "res": {res} 79 | } 80 | ``` 81 | - {instruction}, {input}, and {output} are elements from the original dataset. 82 | - {optimized_prompt} is the BPO-optimized instruction. 83 | - {res} is the response from text-davinci-003 using the {optimized_prompt}. 84 | 85 | 86 | ### Testset 87 | The testset directory contains all the test datasets we used, including: 88 | - 200 prompts sampled from the BPO dataset 89 | - 200 examples from the Dolly dataset 90 | - 252 human evaluation instructions from Self-Instruct 91 | - 80 user-oriented prompts from the Vicuna Eval dataset. 92 | 93 | 94 | ## Quick Start 95 | For all scripts, we have added `#TODO` comments to indicate places in the code that need modification before running. Please update the relevant parts as noted before executing each file.
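As a quick sanity check of the data layout described in the Data section, the snippet below is a minimal loading sketch (it is not part of the repository); the file path and field names are assumptions taken from the repository tree and the format shown above.

```python
import json

# Minimal sketch: load the BPO-reproduced Alpaca data and inspect one record.
# The path is assumed from the repository tree (data/alpaca_reproduced/data_52k.json).
with open("data/alpaca_reproduced/data_52k.json", encoding="utf-8") as f:
    records = json.load(f)

print(f"{len(records)} records")
sample = records[0]
for key in ("instruction", "input", "output", "optimized_prompt", "res"):
    print(f"--- {key} ---")
    print(sample.get(key, ""))
```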
96 | 97 | ### Setup 98 | ```bash 99 | pip install -r requirements.txt 100 | ``` 101 | 102 | ### Data Construction 103 | To construct the data yourself, run the following commands: 104 | ```bash 105 | cd src/data_construction 106 | 107 | # using pairwise feedback data to generate optimized prompts 108 | python chatgpt_infer.py 109 | 110 | # process generated optimized prompts 111 | python process_optimized_prompts.py 112 | ``` 113 | 114 | ### Model Training 115 | If you want to train your own prompt preference optimizer, 116 | please run the following commands: 117 | ```bash 118 | cd src/training 119 | 120 | # pre-process fine-tuning data 121 | python ../data_construction/process_en.py 122 | python data_utils.py 123 | 124 | # fine-tuning 125 | python train.py 126 | 127 | # inference 128 | python infer_finetuning.py 129 | ``` 130 | 131 | ### Inference 132 | We provide [example code](src/inference/llama2_infer.py) for generation with llama2-chat on BPO-optimized prompts. 133 | 134 | ### Evaluation 135 | If you wish to compare the BPO-aligned model with the original model, please refer to the following commands: 136 | ```bash 137 | cd src/evaluation 138 | 139 | # take gpt4 evaluation on dolly_eval as an example; change --task_name to "self_instruct", "test_set", or "vicuna_eval" for other testsets 140 | python gpt4_score.py --input_file_a "Path to generation results of BPO-aligned model" \ 141 |     --input_file_b "Path to generation results of original model" \ 142 |     --task_name "dolly_eval" \ 143 |     --output_file "Output path" 144 | 145 | # calculate win rates 146 | python cal_gpt4_score.py --input_file "Output path" 147 | ``` 148 | 149 | 150 | ## Acknowledgement 151 | - Fine-tuning code: [llm_finetuning](https://github.com/ssbuild/llm_finetuning) 152 | - PPO code: [DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/README.md) 153 | - DPO code: [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) 154 | - Evaluation Prompts: [llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge) and [alpaca_eval](https://github.com/tatsu-lab/alpaca_eval) 155 | 156 | ## Citation 157 | ``` 158 | @article{cheng2023black, 159 | title={Black-Box Prompt Optimization: Aligning Large Language Models without Model Training}, 160 | author={Cheng, Jiale and Liu, Xiao and Zheng, Kehan and Ke, Pei and Wang, Hongning and Dong, Yuxiao and Tang, Jie and Huang, Minlie}, 161 | journal={arXiv preprint arXiv:2311.04155}, 162 | year={2023} 163 | } 164 | ``` 165 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | requests 3 | transformers 4 | deepspeed 5 | aigc_zoo==0.2.4 6 | deep_training==0.2.4 -------------------------------------------------------------------------------- /src/data_construction/chatgpt_infer.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | import time 5 | import random 6 | 7 | # TODO add api key 8 | API_KEY = 'Your-API-Key' 9 | 10 | HEADERS = { 11 | "Content-Type": "application/json", 12 | "Authorization": f"Bearer {API_KEY}" 13 | } 14 | 15 | API_URL = "https://api.openai.com/v1/chat/completions" 16 | 17 | 18 | def chat_gpt(messages, counter, error_count): 19 | responses = [] 20 | for i, m in enumerate(messages): 21 | try: 22 | message = m['message'] 23 | data = json.dumps({"model": "gpt-3.5-turbo", "messages": message,
'temperature':0.9}) 24 | response = requests.post(API_URL, headers=HEADERS, data=data) 25 | response_json = response.json() 26 | res = response_json['choices'][0]['message']['content'] 27 | m['response'] = res 28 | # save to file 29 | with open(output_file, 'a', encoding='utf-8') as f: 30 | print(json.dumps(m, ensure_ascii=False), file=f) 31 | 32 | responses.append(response_json) 33 | 34 | counter += 1 35 | except Exception as e: 36 | error_count += 1 37 | print(e) 38 | print('running time:{} finished number:{} skipped number:{}'.format(time.time()-s_time, counter, error_count), end='\r') 39 | 40 | return responses 41 | 42 | 43 | def get_messages_list(): 44 | evaluated = [] 45 | with open(output_file, encoding='utf-8') as f: 46 | lines = f.readlines() 47 | for i in lines: 48 | evaluated.append(json.loads(i)['origin']) 49 | 50 | with open(input_file, encoding='utf-8') as f: 51 | d = json.load(f) 52 | 53 | messages_list = [] 54 | 55 | ctx_prompt = """instruction: "{}" 56 | context: 57 | "{}" 58 | 59 | bad response: 60 | "{}" 61 | 62 | good response: 63 | "{}" 64 | 65 | Compare the good response and bad response from these aspects: correctness (if the response follows the instruction correctly and give an accurate response, high priority), helpfulness(like depth, creativity, coherence) and harmlessness. Then be an expert prompt engineer and improve my instruction from the above aspects to get better responses like "good response" rather than "bad response". 66 | 67 | Pay attention to: 68 | 1.Don't forget any information in the original instruction. Focus on maintaining all the information in my instruction. 69 | 2.Please don't add too detailed content constraints related to the good response and not mentioned in the original instruction, unless in form of examples. 70 | 3.Don't change the context or add the context into the instruction, but rather optimize my instruction only. Don't give a response to my instruction. 71 | 4.Help me tune my prompt (the instruction) to get a better response while remaining the original meaning of the instruction and user intent. 72 | 73 | Output with the following format: 74 | Detailed Comparison Result: xxx 75 | Optimized Instruction: xxx [END]""" 76 | 77 | no_ctx_prompt = """instruction: "{}" 78 | 79 | bad response: 80 | "{}" 81 | 82 | good response: 83 | "{}" 84 | 85 | Compare the good response and bad response from these aspects: correctness (if the response follows the instruction correctly and give an accurate response, high priority), helpfulness(like depth, creativity, coherence) and harmlessness. Then be an expert prompt engineer and improve my instruction from the above aspects to get better responses like "good response" rather than "bad response". 86 | 87 | Pay attention to: 88 | 1.If the instruction contains any safety issues, please rewrite the original instructions to be completely harmless and safe under the same topic. 89 | 2.Don't forget any information in the original instruction. Focus on maintaining all the information in my instruction. 90 | 3.Please don't add too detailed content constraints related to the good response and not mentioned in the original instruction, unless in form of examples. 91 | 4.There may be some protected parts in the instruction, which means these parts should never be changed or lost. Please carefully protect these parts. 92 | 5.You should never generate a response to the original instruction! 
93 | 6.Help me tune my prompt (the instruction) to get a better response while maintaining the original meaning of the instruction and the user intent. 94 | 95 | Output with the following format: 96 | Detailed Comparison Result: xxx 97 | Optimized Instruction: xxx [END]""" 98 | 99 | for i in d: 100 | if i in evaluated: 101 | continue 102 | if 'context' in i: 103 | text = ctx_prompt.format(i['instruction'], i['context'], i['bad_res'], i['good_res']) 104 | else: 105 | text = no_ctx_prompt.format(i['instruction'], i['bad_res'], i['good_res']) 106 | messages_list.append({ 107 | 'message': [ 108 | {"role": "user", "content": text} 109 | ], 110 | 'origin': i 111 | }) 112 | 113 | return messages_list 114 | 115 | 116 | if __name__ == '__main__': 117 | # TODO input file and output file 118 | input_file = '../../data/data_construction/examples_ctx.json' 119 | output_file = '../../data/data_construction/examples_ctx_optimized.jsonl' 120 | if not os.path.exists(output_file): 121 | x = open(output_file, 'w') 122 | x.close() 123 | messages_list = get_messages_list() 124 | print("total num: ", len(messages_list)) 125 | s_time = time.time() 126 | responses = chat_gpt(messages_list, 0, 0) -------------------------------------------------------------------------------- /src/data_construction/process_en.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | with open('../../data/data_construction/examples_ctx_optimized.json', encoding='utf-8') as f: 5 | d = json.load(f) 6 | 7 | res = [] 8 | num = 0 9 | for i in d: 10 | q = i['prompt'] 11 | a = i['optimized_prompt'] 12 | try: 13 | a = eval(a) 14 | except: 15 | pass 16 | res.append(json.dumps({ 17 | 'id': num, 18 | "paragraph": [ 19 | { 20 | 'q': q, 21 | 'a': a 22 | } 23 | ], 24 | }, ensure_ascii=False) + '\n') 25 | num += 1 26 | 27 | with open('data/train.jsonl', 'w', encoding='utf-8') as f: 28 | f.writelines(res) 29 | 30 | -------------------------------------------------------------------------------- /src/data_construction/process_optimized_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from tqdm import trange, tqdm 4 | import os 5 | 6 | # Preprocess code for dataset with context, like Alpaca-gpt4 7 | def process_ctx(input_file, output_file): 8 | 9 | with open(input_file) as f: 10 | l = f.readlines() 11 | 12 | res = [] 13 | for i in l: 14 | i = json.loads(i) 15 | response = i['response'].split('[END]')[0] 16 | if not response.count('Optimized Instruction:'): 17 | print(response) 18 | continue 19 | else: 20 | response = response.split('Optimized Instruction:') 21 | try: 22 | prompt = eval(response[1]).strip() 23 | except: 24 | prompt = response[1].strip() 25 | i['origin']['comparison'] = response[0] 26 | i['origin']['optimized_instruction'] = prompt 27 | res.append(i['origin']) 28 | 29 | 30 | data = [] 31 | for i in tqdm(res): 32 | if not len(i['context']): 33 | i['prompt'] = i['instruction'] 34 | i['optimized_prompt'] = i['optimized_instruction'] 35 | else: 36 | # optimized instruction contains context 37 | if i['optimized_instruction'].lower().count(i['context'].lower()): 38 | i['prompt'] = (i['instruction'] + '\n' + i['context']).strip() 39 | i['optimized_prompt'] = i['optimized_instruction'] 40 | else: 41 | # using the format {instruction}\n{context} 42 | if i['optimized_instruction'].count('follow') or i['instruction'].count('follow') or i['instruction'][-1] == ':': 43 | i['prompt'] = 
(i['instruction'] + '\n' + i['context']).strip() 44 | i['optimized_prompt'] = (i['optimized_instruction'] + '\n' + i['context']).strip() 45 | else: 46 | if random.random()< 0.5: 47 | if random.random() < 0.5: 48 | # using the format {instruction}\n{context} 49 | i['prompt'] = (i['instruction'] + '\n' + i['context']).strip() 50 | i['optimized_prompt'] = (i['optimized_instruction'] + '\n' + i['context']).strip() 51 | else: 52 | # using the format {context}\n{instruction} 53 | i['prompt'] = (i['context'] + '\n' + i['instruction']).strip() 54 | i['optimized_prompt'] = (i['context'] + '\n' + i['optimized_instruction']).strip() 55 | else: 56 | if random.random() < 0.25: 57 | if random.random() < 0.5: 58 | # using the format {instruction} {context} 59 | i['prompt'] = (i['instruction'] + ' ' + i['context']).strip() 60 | i['optimized_prompt'] = (i['optimized_instruction'] + ' ' + i['context']).strip() 61 | else: 62 | # using the format {context} {instruction} 63 | i['prompt'] = (i['context'] + ' ' + i['instruction']).strip() 64 | i['optimized_prompt'] = (i['context'] + ' ' + i['optimized_instruction']).strip() 65 | else: 66 | if random.random() < 0.5: 67 | # using the format {instruction} "{context}" 68 | i['prompt'] = (i['instruction'] + ' "' + i['context'] + '"').strip() 69 | i['optimized_prompt'] = (i['optimized_instruction'] + ' "' + i['context'] + '"').strip() 70 | else: 71 | # using the format {context} "{instruction}" 72 | i['prompt'] = ('"'+ i['context'] + '" ' + i['instruction']).strip() 73 | i['optimized_prompt'] = ('"' + i['context'] + '" ' + i['optimized_instruction']).strip() 74 | data.append(i) 75 | 76 | with open(output_file, 'w', encoding='utf-8') as f: 77 | json.dump(data, f, indent=4, ensure_ascii=False) 78 | 79 | 80 | # Preprocess code for dataset without context, like Chatbot Arena Conversation 81 | def process_no_ctx(input_file, output_file): 82 | 83 | with open(input_file) as f: 84 | l = f.readlines() 85 | 86 | res = [] 87 | for i in l: 88 | i = json.loads(i) 89 | response = i['response'].split('[END]')[0] 90 | if not response.count('Optimized Instruction:'): 91 | print(response) 92 | continue 93 | else: 94 | response = response.split('Optimized Instruction:') 95 | try: 96 | prompt = eval(response[1]).strip() 97 | except: 98 | prompt = response[1].strip() 99 | i['origin']['comparison'] = response[0] 100 | i['origin']['optimized_instruction'] = prompt 101 | res.append(i['origin']) 102 | 103 | data = [] 104 | num = 0 105 | for i in res: 106 | if len(i['instruction'].split()) / len(i['optimized_instruction'].split()) > 2 or len(i['optimized_instruction'].split()) / len(i['instruction'].split()) > 6: 107 | # filter data that may be error 108 | continue 109 | if i['optimized_instruction'].lower().count('[protected'): 110 | # filter data contains special string 111 | continue 112 | i['prompt'] = i['instruction'] 113 | i['optimized_prompt'] = i['optimized_instruction'] 114 | data.append(i) 115 | 116 | with open(output_file, 'w', encoding='utf-8') as f: 117 | json.dump(data, f, indent=4, ensure_ascii=False) 118 | 119 | 120 | if __name__ == '__main__': 121 | # TODO add input_file output_file 122 | input_file = '../../data/data_construction/examples_ctx_optimized.jsonl' 123 | output_file = '../../data/data_construction/examples_ctx_optimized.json' 124 | 125 | # TODO choose a function depend on your dataset 126 | 127 | # Preprocess code for dataset with context attribute, like Alpaca-gpt4 128 | process_ctx(input_file, output_file) 129 | 130 | # Preprocess code for dataset without context 
attribute, like Chatbot Arena Conversation 131 | # process_no_ctx(input_file, output_file) -------------------------------------------------------------------------------- /src/evaluation/cal_claude_score.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | 5 | def cal_overall(input_file, judge_key): 6 | with open(input_file) as f: 7 | l = f.readlines() 8 | w_l_t = [0, 0, 0] 9 | num = 0 10 | 11 | str_a = "model_1" 12 | str_b = "model_2" 13 | 14 | print(len(l)) 15 | for i in l: 16 | i = json.loads(i) 17 | if i['response'].split('rank')[0].count(str_a): 18 | if judge_key in i['option_a']: 19 | num += 1 20 | w_l_t[1] += 1 21 | else: 22 | w_l_t[0] += 1 23 | elif i['response'].split('rank')[0].count(str_b): 24 | if judge_key in i['option_a']: 25 | num += 1 26 | w_l_t[0] += 1 27 | else: 28 | w_l_t[1] += 1 29 | else: 30 | print(i['response'].split('rank')[0]) 31 | print(w_l_t) 32 | print(f"Origin v.s. {judge_key}, win lose tie: ", [i / len(l) for i in w_l_t]) 33 | print(f"{judge_key} as first: ", num) 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--input_file', type=str) 39 | args = parser.parse_args() 40 | 41 | # TODO there should be a special key in the dict to distinguish the source model, like 'optimized_prompt' will be in the optimized version 42 | judge_key = 'optimized_prompt' 43 | cal_overall(args.input_file, judge_key) -------------------------------------------------------------------------------- /src/evaluation/cal_gpt4_score.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | 5 | def cal_overall(input_file, judge_key): 6 | with open(input_file) as f: 7 | l = f.readlines() 8 | w_l_t = [0, 0, 0] 9 | num = 0 10 | 11 | for i in l: 12 | i = json.loads(i) 13 | if "[[A]]" in i['response'].split('\n\n')[-1]: 14 | if judge_key in i['option_a']: 15 | num += 1 16 | w_l_t[1] += 1 17 | else: 18 | w_l_t[0] += 1 19 | elif "[[B]]" in i['response'].split('\n\n')[-1]: 20 | if judge_key in i['option_a']: 21 | num += 1 22 | w_l_t[0] += 1 23 | else: 24 | w_l_t[1] += 1 25 | elif "[[C]]" in i['response'].split('\n\n')[-1]: 26 | if judge_key in i['option_a']: 27 | num += 1 28 | w_l_t[2] += 1 29 | 30 | print(w_l_t) 31 | print(f"Origin v.s. 
{judge_key}, win lose tie: ", [i/len(l) for i in w_l_t]) 32 | print(f"{judge_key} as first: ", num) 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--input_file', type=str) 38 | args = parser.parse_args() 39 | 40 | # TODO there should be a special key in the dict to distinguish the source model, like 'optimized_prompt' will be in the optimized version 41 | judge_key = 'optimized_prompt' 42 | cal_overall(args.input_file, judge_key) -------------------------------------------------------------------------------- /src/evaluation/claude_score.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import multiprocessing 3 | from multiprocessing import Manager 4 | import json 5 | from tqdm import tqdm 6 | import os 7 | import time 8 | import pandas as pd 9 | import random 10 | import argparse 11 | from anthropic import Anthropic 12 | 13 | anthropic = Anthropic( 14 | api_key="Your-API-Key", 15 | ) 16 | 17 | 18 | def claude_gen(messages, counter, error_count): 19 | responses = [] 20 | for i, m in enumerate(messages): 21 | try: 22 | message = m['message'] 23 | completion = anthropic.completions.create( 24 | model='claude-v1.3', 25 | max_tokens_to_sample=512, 26 | prompt=f"{message}", 27 | temperature=0.0 28 | ) 29 | print(completion) 30 | resp = completion.completion 31 | m['response'] = resp 32 | # save to file 33 | with open(output_file, 'a', encoding='utf-8') as f: 34 | print(json.dumps(m, ensure_ascii=False), file=f) 35 | 36 | responses.append(resp) 37 | 38 | # Increment and print the counter 39 | counter += 1 40 | except Exception as e: 41 | error_count += 1 42 | print(e) 43 | print('running time:{} finished number:{} skipped number:{}'.format(time.time() - s_time, counter, 44 | error_count), end='\r') 45 | 46 | return responses 47 | 48 | 49 | def get_messages_list(): 50 | if task_name.count("test_set") or task_name.count("dolly"): 51 | idx = "idx" 52 | elif task_name.count("self_instruct"): 53 | idx = "id" 54 | elif task_name.count("vicuna"): 55 | idx = "question_id" 56 | else: 57 | print("Not implemented") 58 | assert False 59 | 60 | evaluated = [] 61 | with open(output_file, encoding='utf-8') as f: 62 | lines = f.readlines() 63 | for i in lines: 64 | evaluated.append(json.loads(i)['origin']) 65 | 66 | with open(input_file_a) as f: 67 | d_a = json.load(f) 68 | 69 | with open(input_file_b) as f: 70 | d_b = json.load(f) 71 | 72 | messages_list = [] 73 | 74 | for i, j in zip(d_a, d_b): 75 | if i[idx] in evaluated: 76 | continue 77 | if random.randint(0, 1) == 0: 78 | option_a = i 79 | res_a = i['res'] 80 | res_b = j['res'] 81 | else: 82 | option_a = j 83 | res_a = j['res'] 84 | res_b = i['res'] 85 | if task_name.count("self_instruct") or task_name.count("dolly"): 86 | question = (i['instruction'] + '\n' + i['context']).strip() 87 | elif task_name.count("test_set"): 88 | question = i['context'].strip() 89 | elif task_name.count("vicuna"): 90 | question = i['text'].strip() 91 | else: 92 | print("Not implemented") 93 | assert False 94 | messages_list.append({'message': prompt.replace('{instruction}', question).replace('{output_1}', res_a).replace( 95 | '{output_2}', res_b), 96 | 'origin': i[idx], 97 | 'option_a': option_a, 98 | }) 99 | 100 | return messages_list 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | 106 | parser.add_argument('--input_file_a', type=str) 107 | parser.add_argument('--input_file_b', type=str) 108 | 
parser.add_argument('--task_name', type=str) 109 | parser.add_argument('--output_file', type=str) 110 | args = parser.parse_args() 111 | 112 | input_file_a = args.input_file_a 113 | input_file_b = args.input_file_b 114 | task_name = args.task_name 115 | output_file = args.output_file 116 | 117 | with open('./evaluation/ranking_prompt.txt') as f: 118 | lines = f.readlines() 119 | prompt = '' 120 | for i in lines: 121 | prompt = prompt + i 122 | if not os.path.exists(output_file): 123 | x = open(output_file, 'w') 124 | x.close() 125 | messages_list = get_messages_list() 126 | print("total num: ", len(messages_list)) 127 | s_time = time.time() 128 | responses = claude_gen(messages_list, 0, 0) 129 | -------------------------------------------------------------------------------- /src/evaluation/gpt4_score.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import multiprocessing 3 | from multiprocessing import Manager 4 | import json 5 | from tqdm import tqdm 6 | import os 7 | import time 8 | import pandas as pd 9 | import random 10 | import argparse 11 | 12 | API_KEY = 'Your-API-Key' 13 | 14 | HEADERS = { 15 | "Content-Type": "application/json", 16 | "Authorization": f"Bearer {API_KEY}" 17 | } 18 | 19 | API_URL = "https://api.openai.com/v1/chat/completions" 20 | 21 | def chat_gpt(messages, counter, error_count): 22 | responses = [] 23 | for i, m in enumerate(messages): 24 | try: 25 | message = m['message'] 26 | data = json.dumps({"model": "gpt-4", "messages": message, 'temperature': 0.0}) 27 | response = requests.post(API_URL, headers=HEADERS, data=data) 28 | response_json = response.json() 29 | print(response_json) 30 | res = response_json['choices'][0]['message']['content'] 31 | m['response'] = res 32 | # save to file 33 | with open(output_file, 'a', encoding='utf-8') as f: 34 | print(json.dumps(m, ensure_ascii=False), file=f) 35 | 36 | responses.append(response_json) 37 | 38 | # Increment and print the counter 39 | counter += 1 40 | except Exception as e: 41 | error_count += 1 42 | print(e) 43 | print('running time:{} finished number:{} skipped number:{}'.format(time.time()-s_time, counter, error_count), end='\r') 44 | 45 | return responses 46 | 47 | 48 | def get_messages_list(): 49 | 50 | if task_name.count("test_set") or task_name.count("dolly"): 51 | idx = "idx" 52 | elif task_name.count("self_instruct"): 53 | idx = "id" 54 | elif task_name.count("vicuna"): 55 | idx = "question_id" 56 | else: 57 | print("idx Not implemented") 58 | assert False 59 | 60 | evaluated = [] 61 | with open(output_file, encoding='utf-8') as f: 62 | lines = f.readlines() 63 | for i in lines: 64 | evaluated.append(json.loads(i)['origin']) 65 | 66 | with open(input_file_a) as f: 67 | d_a = json.load(f) 68 | 69 | with open(input_file_b) as f: 70 | d_b = json.load(f) 71 | 72 | messages_list = [] 73 | 74 | for i,j in zip(d_a, d_b): 75 | assert (i[idx] == j[idx]) 76 | if i[idx] in evaluated: 77 | continue 78 | if random.randint(0, 1) == 0: 79 | option_a = i 80 | res_a = i['res'] 81 | res_b = j['res'] 82 | else: 83 | option_a = j 84 | res_a = j['res'] 85 | res_b = i['res'] 86 | if task_name.count("self_instruct") or task_name.count("dolly"): 87 | question = (i['instruction']+'\n'+i['context']).strip() 88 | elif task_name.count("test_set"): 89 | question = i['context'].strip() 90 | elif task_name.count("vicuna"): 91 | question = i['text'].strip() 92 | else: 93 | print("Not implemented") 94 | assert False 95 | messages_list.append({'message': [ 96 | {"role": 
'system', "content": prompt['system_prompt']}, 97 | {"role": "user", "content": prompt['prompt_template'].replace('{question}', question).replace('{answer_a}', res_a).replace('{answer_b}', res_b)} 98 | ], 99 | 'origin': i[idx], 100 | 'option_a': option_a, 101 | }) 102 | 103 | return messages_list 104 | 105 | 106 | if __name__ == '__main__': 107 | parser = argparse.ArgumentParser() 108 | 109 | parser.add_argument('--input_file_a', type=str) 110 | parser.add_argument('--input_file_b', type=str) 111 | parser.add_argument('--task_name', type=str) 112 | parser.add_argument('--output_file', type=str) 113 | args = parser.parse_args() 114 | 115 | input_file_a = args.input_file_a 116 | input_file_b = args.input_file_b 117 | task_name = args.task_name 118 | output_file = args.output_file 119 | 120 | prompt = {"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"} 121 | if not os.path.exists(output_file): 122 | x = open(output_file, 'w') 123 | x.close() 124 | messages_list = get_messages_list(task_name) 125 | print("total num: ", len(messages_list)) 126 | s_time = time.time() 127 | responses = chat_gpt(messages_list, 0, 0) 128 | -------------------------------------------------------------------------------- /src/evaluation/ranking_prompt.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | Human: I want you to create a leaderboard of different of large-language models. To do so, I will give you the instructions (prompts) given to the models, and the responses of two models. Please rank the models based on which responses would be preferred by humans. All inputs and outputs should be python dictionaries. 4 | 5 | Here is the prompt: 6 | { 7 | "instruction": """{instruction}""", 8 | } 9 | 10 | Here are the outputs of the models: 11 | [ 12 | { 13 | "model": "model_1", 14 | "answer": """{output_1}""" 15 | }, 16 | { 17 | "model": "model_2", 18 | "answer": """{output_2}""" 19 | } 20 | ] 21 | 22 | Now please rank the models by the quality of their answers, so that the model with rank 1 has the best output. 
Then return a list of the model names and ranks, i.e., produce the following output: 23 | [ 24 | {'model': <model-name>, 'rank': <model-rank>}, 25 | {'model': <model-name>, 'rank': <model-rank>} 26 | ] 27 | 28 | Your response must be a valid Python dictionary and should contain nothing else because we will directly execute it in Python. Please provide the ranking that the majority of humans would give. 29 | 30 | Assistant: -------------------------------------------------------------------------------- /src/infer_example.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | 4 | # TODO change model path 5 | model_path = 'THUDM/BPO' 6 | 7 | prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]" 8 | 9 | device = 'cuda:0' 10 | model = AutoModelForCausalLM.from_pretrained(model_path).half().eval().to(device) 11 | # for 8bit 12 | # model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, load_in_8bit=True) 13 | tokenizer = AutoTokenizer.from_pretrained(model_path, add_prefix_space=True) 14 | 15 | 16 | def gen(input_text): 17 | prompt = prompt_template.format(input_text) 18 | model_inputs = tokenizer(prompt, return_tensors="pt").to(device) 19 | output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.6, num_beams=1) 20 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip() 21 | 22 | print("[Stable Optimization] ", resp) 23 | 24 | 25 | def gen_aggressive(input_text): 26 | texts = [input_text] * 5 27 | responses = [] 28 | for text in texts: 29 | seed = torch.seed() 30 | torch.manual_seed(seed) 31 | prompt = prompt_template.format(text) 32 | min_length = len(tokenizer(prompt)['input_ids']) + len(tokenizer(text)['input_ids']) + 5 33 | model_inputs = tokenizer(prompt, return_tensors="pt").to(device) 34 | bad_words_ids = [tokenizer(bad_word, add_special_tokens=False).input_ids for bad_word in ["[PROTECT]", "\n\n[PROTECT]", "[KEEP", "[INSTRUCTION]"]] 35 | # eos and \n 36 | eos_token_ids = [tokenizer.eos_token_id, 13] 37 | output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.9, bad_words_ids=bad_words_ids, num_beams=1, eos_token_id=eos_token_ids, min_length=min_length) 38 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].split('[KE')[0].split('[INS')[0].split('[PRO')[0].strip() 39 | responses.append(resp) 40 | 41 | for i in responses: 42 | print("[Aggressive Optimization] ", i) 43 | 44 | 45 | text = 'how can I create a profile on Facebook?' 
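# NOTE (editor's addition, hedged): the two calls below exercise the two decoding modes defined
# above. gen() samples conservatively (temperature 0.6) and may leave the prompt unchanged, while
# gen_aggressive() samples five rewrites at temperature 0.9, blocks what appear to be the BPO
# model's internal control markers ("[PROTECT]", "[KEEP", "[INSTRUCTION]") via bad_words_ids, and
# sets min_length to roughly the prompt plus input length, which presumably nudges the model to
# rewrite the prompt rather than echo it.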
46 | 47 | # Stable optimization, this will sometimes maintain the original prompt 48 | gen(text) 49 | 50 | # Agressive optimization, this will refine the original prompt with a higher possibility 51 | # but there may be inappropriate changes 52 | gen_aggressive(text) 53 | -------------------------------------------------------------------------------- /src/inference/llama2_infer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | from tqdm import tqdm 3 | import json 4 | import torch 5 | import time 6 | from collections import OrderedDict 7 | 8 | device = 'cuda:0' 9 | 10 | 11 | model_name = "Llama-2-7b-chat-hf" 12 | prompt_template = "[INST] {} [/INST]" 13 | 14 | 15 | model = AutoModelForCausalLM.from_pretrained(model_name).half().eval().to(device) 16 | tokenizer = AutoTokenizer.from_pretrained(model_name) 17 | 18 | 19 | # BPO-optimized prompts 20 | with open('dolly_eval_optimized.json') as f: 21 | data = json.load(f) 22 | 23 | 24 | with torch.no_grad(): 25 | res = [] 26 | for i in tqdm(data): 27 | input_text = prompt_template.format((i['optimized_prompt']).strip()) 28 | model_inputs = tokenizer(input_text, return_tensors="pt").to(device) 29 | 30 | output = model.generate(**model_inputs, max_new_tokens=2048, do_sample=True, top_p=1.0, temperature=0.7) 31 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip() 32 | i['res'] = resp 33 | res.append(i) 34 | 35 | with open('dolly_eval_optimized_llama2_7b_res.json', 'w', encoding='utf-8') as f: 36 | json.dump(res, f, indent=4, ensure_ascii=False) 37 | -------------------------------------------------------------------------------- /src/training/config/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | # @Time : 2023/5/12 20:39 3 | # @Author : tk 4 | # @FileName: __init__.py 5 | 6 | from config.main import * 7 | 8 | -------------------------------------------------------------------------------- /src/training/config/constant_map.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 23:20 3 | # @Author: tk 4 | # @File:model_maps 5 | 6 | __model_path__ = { 7 | 'bloom-560m': { 8 | 'model_type': 'bloom', 9 | 'model_name_or_path': '/data/nlp/pre_models/torch/bloom/bloom-560m', 10 | 'config_name': '/data/nlp/pre_models/torch/bloom/bloom-560m/config.json', 11 | 'tokenizer_name': '/data/nlp/pre_models/torch/bloom/bloom-560m', 12 | }, 13 | 'bloom-1b7': { 14 | 'model_type': 'bloom', 15 | 'model_name_or_path': '/data/nlp/pre_models/torch/bloom/bloom-1b7', 16 | 'config_name': '/data/nlp/pre_models/torch/bloom/bloom-1b7/config.json', 17 | 'tokenizer_name': '/data/nlp/pre_models/torch/bloom/bloom-1b7', 18 | }, 19 | 'opt-350m': { 20 | 'model_type': 'opt', 21 | 'model_name_or_path': '/data/nlp/pre_models/torch/opt/opt-350m', 22 | 'config_name': '/data/nlp/pre_models/torch/opt/opt-350m/config.json', 23 | 'tokenizer_name': '/data/nlp/pre_models/torch/opt/opt-350m', 24 | }, 25 | 26 | 'llama-7b-hf': { 27 | 'model_type': 'llama', 28 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/llama-7b-hf', 29 | 'config_name': '/data/nlp/pre_models/torch/llama/llama-7b-hf/config.json', 30 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/llama-7b-hf', 31 | }, 32 | 33 | 'llama-13b-hf': { 34 | 'model_type': 'llama', 35 | 'model_name_or_path': '/cjl/pretrained_models/llama-13b-hf', 36 | 'config_name': 
'/cjl/pretrained_models/llama-13b-hf/config.json', 37 | 'tokenizer_name': '/cjl/pretrained_models/llama-13b-hf', 38 | }, 39 | 40 | # TODO change model path 41 | 'Llama-2-7b-chat-hf':{ 42 | 'model_type': 'llama', 43 | 'model_name_or_path': '/cjl/pretrained_models/Llama-2-7b-chat-hf', 44 | 'config_name': '/cjl/pretrained_models/Llama-2-7b-chat-hf/config.json', 45 | 'tokenizer_name': '/cjl/pretrained_models/Llama-2-7b-chat-hf', 46 | }, 47 | 48 | 'Llama2-Chinese-7b-Chat':{ 49 | 'model_type': 'llama', 50 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-7b-Chat', 51 | 'config_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-7b-Chat/config.json', 52 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-7b-Chat', 53 | }, 54 | 55 | 'Llama2-Chinese-13b-Chat':{ 56 | 'model_type': 'llama', 57 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-13b-Chat', 58 | 'config_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-13b-Chat/config.json', 59 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-13b-Chat', 60 | }, 61 | 62 | 'chatyuan-7b': { 63 | 'model_type': 'llama', 64 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/ChatYuan-7B', 65 | 'config_name': '/data/nlp/pre_models/torch/llama/ChatYuan-7B/config.json', 66 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/ChatYuan-7B', 67 | }, 68 | 'tigerbot-13b-chat': { 69 | 'model_type': 'llama', 70 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat', 71 | 'config_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat/config.json', 72 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat', 73 | }, 74 | 'tigerbot-13b-chat-int4': { 75 | 'model_type': 'llama', 76 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat-int4', 77 | 'config_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat-int4/config.json', 78 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat-int4', 79 | }, 80 | 81 | 'openbuddy-llama2-70b-v10.1': { 82 | 'model_type': 'llama', 83 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/openbuddy-llama2-70b-v10.1-bf16', 84 | 'config_name': '/data/nlp/pre_models/torch/llama/openbuddy-llama2-70b-v10.1-bf16/config.json', 85 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/openbuddy-llama2-70b-v10.1-bf16', 86 | }, 87 | 88 | 89 | 90 | 'rwkv-4-430m-pile': { 91 | 'model_type': 'rwkv', 92 | 'model_name_or_path': '/data/nlp/pre_models/torch/rwkv/rwkv-4-430m-pile', 93 | 'config_name': '/data/nlp/pre_models/torch/rwkv/rwkv-4-430m-pile/config.json', 94 | 'tokenizer_name': '/data/nlp/pre_models/torch/rwkv/rwkv-4-430m-pile', 95 | }, 96 | 97 | } 98 | 99 | 100 | # 'target_modules': ['query_key_value'], # bloom,gpt_neox 101 | # 'target_modules': ["q_proj", "v_proj"], #llama,opt,gptj,gpt_neo 102 | # 'target_modules': ['c_attn'], #gpt2 103 | # 'target_modules': ['project_q','project_v'] # cpmant 104 | 105 | train_target_modules_maps = { 106 | 't5': ['qkv_proj'], 107 | 'moss': ['qkv_proj'], 108 | 'chatglm': ['query_key_value'], 109 | 'bloom' : ['query_key_value'], 110 | 'gpt_neox' : ['query_key_value'], 111 | 'llama' : ["q_proj", "v_proj"], 112 | 'opt' : ["q_proj", "v_proj"], 113 | 'gptj' : ["q_proj", "v_proj"], 114 | 'gpt_neo' : ["q_proj", "v_proj"], 115 | 'gpt2' : ['c_attn'], 116 | 'cpmant' : ['project_q','project_v'], 117 | 'rwkv' : ['key','value','receptance'], 118 | } 119 | 120 | 121 | train_model_config = __model_path__['Llama-2-7b-chat-hf'] 122 | 123 | 
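# NOTE (editor's addition): hedged sketch, not part of the original file. Every entry above
# follows the same schema and `train_model_config` simply selects one of them, so a locally
# downloaded checkpoint can be registered the same way. The name and path below are
# placeholders, not paths shipped with this repository.
import os

def register_model(name, path, model_type='llama'):
    """Add an entry to __model_path__ mirroring the schema used above and return it."""
    __model_path__[name] = {
        'model_type': model_type,
        'model_name_or_path': path,
        'config_name': os.path.join(path, 'config.json'),
        'tokenizer_name': path,
    }
    return __model_path__[name]

# Example (commented out so the Llama-2-7b-chat-hf selection above remains the default):
# train_model_config = register_model('my-llama-2-7b-chat', '/path/to/Llama-2-7b-chat-hf')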
-------------------------------------------------------------------------------- /src/training/config/deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_allow_untested_optimizer": true, 3 | "fp16": { 4 | "enabled": true, 5 | "auto_cast": false, 6 | "loss_scale": 0, 7 | "initial_scale_power": 16, 8 | "loss_scale_window": 1000, 9 | "hysteresis": 2, 10 | "min_loss_scale": 1 11 | }, 12 | "zero_optimization": { 13 | "stage": 2, 14 | "allgather_partitions": true, 15 | "allgather_bucket_size": 5e8, 16 | "overlap_comm": false, 17 | "reduce_scatter": true, 18 | "reduce_bucket_size": 5e8, 19 | "contiguous_gradients" : true, 20 | 21 | "stage3_max_live_parameters" : 1e9, 22 | "stage3_max_reuse_distance" : 1e9, 23 | "stage3_prefetch_bucket_size" : 5e8, 24 | "stage3_param_persistence_threshold" : 1e6, 25 | "sub_group_size" : 1e12, 26 | "elastic_checkpoint" : true, 27 | "stage3_gather_16bit_weights_on_model_save": true, 28 | "ignore_unused_parameters": true, 29 | "round_robin_gradients": true 30 | } 31 | } -------------------------------------------------------------------------------- /src/training/config/deepspeed_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "steps_per_print": 1, 3 | "gradient_clipping": 1.0, 4 | "optimizer": { 5 | "type": "AdamW", 6 | "params": { 7 | "lr": 0, 8 | "betas": [0.9, 0.999], 9 | "eps": 1e-8, 10 | "weight_decay": 1e-2 11 | } 12 | }, 13 | "scheduler": { 14 | "type": "WarmupDecayLR", 15 | "params": { 16 | "warmup_min_lr": 0, 17 | "warmup_max_lr": 2e-5, 18 | "warmup_num_steps": "auto", 19 | "warmup_type": "linear", 20 | "total_num_steps": "auto" 21 | } 22 | }, 23 | "zero_allow_untested_optimizer": true, 24 | "fp16": { 25 | "enabled": false 26 | }, 27 | "zero_optimization": { 28 | "stage": 2, 29 | "allgather_partitions": true, 30 | "allgather_bucket_size": 5e8, 31 | "overlap_comm": false, 32 | "reduce_scatter": true, 33 | "reduce_bucket_size": 5e8, 34 | "contiguous_gradients": true, 35 | "stage3_max_live_parameters": 1e9, 36 | "stage3_max_reuse_distance": 1e9, 37 | "stage3_prefetch_bucket_size": 5e8, 38 | "stage3_param_persistence_threshold": 1e6, 39 | "sub_group_size": 1e12, 40 | "elastic_checkpoint": true, 41 | "stage3_gather_16bit_weights_on_model_save": true, 42 | "ignore_unused_parameters": true, 43 | "round_robin_gradients": true, 44 | "offload_optimizer": { 45 | "device": "cpu", 46 | "pin_memory": true 47 | } 48 | } 49 | } -------------------------------------------------------------------------------- /src/training/config/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/5/31 14:43 4 | import json 5 | import os 6 | import torch 7 | from transformers import BitsAndBytesConfig 8 | 9 | # 全局配置 10 | global_args = { 11 | # 训练配置 12 | **dict( 13 | trainer_backend ='pl', # one of pl , hf 14 | enable_deepspeed = True, 15 | enable_ptv2 = False, 16 | enable_lora = False, 17 | load_in_bit = 0, # 4 load_in_4bit, 8 load_in_8bit other 0 18 | ), 19 | #与 transformers config合并 20 | "config_merge": { 21 | }, 22 | # qlora 23 | "quantization_config": BitsAndBytesConfig( 24 | load_in_8bit =False, 25 | load_in_4bit = False, 26 | llm_int8_threshold=6.0, 27 | llm_int8_has_fp16_weight=False, 28 | bnb_4bit_compute_dtype=torch.float16 if not torch.cuda.is_bf16_supported() else torch.bfloat16, 29 | bnb_4bit_use_double_quant=True, 30 | bnb_4bit_quant_type="nf4", 31 | ), 32 
| } 33 | 34 | 35 | 36 | 37 | 38 | if global_args["enable_lora"]: 39 | from config.sft_config_lora import train_info_args,train_info_args_hf,train_model_config 40 | elif global_args["enable_ptv2"]: 41 | from config.sft_config_ptv2 import train_info_args,train_info_args_hf,train_model_config 42 | else: 43 | from config.sft_config import train_info_args,train_info_args_hf,train_model_config 44 | 45 | 46 | if global_args["trainer_backend"] == "hf": 47 | train_info_args = train_info_args_hf 48 | 49 | 50 | 51 | 52 | 53 | def patch_args(train_info_args): 54 | assert global_args["enable_lora"] + global_args["enable_ptv2"] <= 1 , ValueError("lora ptv2 cannot open at same time") 55 | 56 | if global_args['quantization_config'] is not None: 57 | global_args['quantization_config'].load_in_4bit = global_args["load_in_bit"] == 4 58 | global_args['quantization_config'].load_in_8bit = global_args["load_in_bit"] == 8 59 | if global_args["load_in_bit"] == 0: 60 | global_args["quantization_config"] = None 61 | 62 | if global_args["enable_lora"]: 63 | #检查lora adalora是否开启 64 | if 'lora' not in train_info_args and 'adalora' not in train_info_args: 65 | raise ValueError('please config lora or adalora') 66 | if train_info_args.get('lora',{}).get('with_lora',False) and train_info_args.get('adalora',{}).get('with_lora',False): 67 | raise Exception('lora and adalora can set one at same time !') 68 | 69 | train_info_args.pop('prompt', None) 70 | elif global_args["enable_ptv2"]: 71 | train_info_args.pop('lora', None) 72 | train_info_args.pop('adalora', None) 73 | if hasattr(train_info_args,"gradient_checkpointing"): 74 | train_info_args.gradient_checkpointing = False 75 | else: 76 | train_info_args.pop('lora',None) 77 | train_info_args.pop('adalora', None) 78 | train_info_args.pop('prompt', None) 79 | 80 | # 预处理 81 | if 'rwkv' in train_info_args[ 'tokenizer_name' ].lower(): 82 | train_info_args[ 'use_fast_tokenizer' ] = True 83 | 84 | 85 | 86 | patch_args(train_info_args) 87 | 88 | 89 | def get_deepspeed_config(precision='fp16'): 90 | ''' 91 | lora prompt finetuning deepspeed_offload.json 92 | 普通 finetuning deepspeed.json 93 | ''' 94 | # 是否开启deepspeed 95 | if not global_args["enable_deepspeed"]: 96 | return None 97 | precision = str(precision).lower() 98 | # 选择 deepspeed 配置文件 99 | is_need_update_config = False 100 | if global_args["enable_lora"]: 101 | is_need_update_config = True 102 | filename = os.path.join(os.path.dirname(__file__), 'deepspeed_offload.json') 103 | else: 104 | # filename = os.path.join(os.path.dirname(__file__), 'deepspeed.json') 105 | is_need_update_config = True 106 | filename = os.path.join(os.path.dirname(__file__), 'deepspeed_offload.json') 107 | 108 | 109 | with open(filename, mode='r', encoding='utf-8') as f: 110 | deepspeed_config = json.loads(f.read()) 111 | 112 | #lora offload 同步优化器配置 113 | if is_need_update_config: 114 | optimizer = deepspeed_config.get('optimizer',None) 115 | if optimizer: 116 | if global_args["trainer_backend"] == 'hf': 117 | optimizer[ 'params' ][ 'betas' ] = (train_info_args.get('adam_beta1', 0.9),train_info_args.get('adam_beta2', 0.999),) 118 | optimizer[ 'params' ][ 'lr' ] = train_info_args.get('learning_rate', 2e-5) 119 | optimizer[ 'params' ][ 'eps' ] = train_info_args.get('adam_epsilon', 1e-8) 120 | # deepspeed_offload 优化器有效 121 | train_info_args[ 'optim' ] = optimizer[ 'type' ] 122 | else: 123 | optimizer['params']['betas'] = train_info_args.get('optimizer_betas', (0.9, 0.999)) 124 | optimizer['params']['lr'] = train_info_args.get('learning_rate', 2e-5) 125 | 
optimizer['params']['eps'] = train_info_args.get('adam_epsilon', 1e-8) 126 | # deepspeed_offload 优化器有效 127 | train_info_args['optimizer'] = optimizer['type'] 128 | 129 | if precision == 'bf16': 130 | if 'fp16' in deepspeed_config: 131 | deepspeed_config["fp16"]["enabled"] = False 132 | if 'bf16' in deepspeed_config: 133 | deepspeed_config["bf16"]["enabled"] = True 134 | else: 135 | deepspeed_config['bf16'] = {"enabled": True} 136 | elif precision == 'fp16': 137 | if 'bf16' in deepspeed_config: 138 | deepspeed_config["bf16"]["enabled"] = False 139 | 140 | return deepspeed_config 141 | 142 | -------------------------------------------------------------------------------- /src/training/config/sft_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/5/16 10:13 3 | 4 | import json 5 | import os 6 | import torch 7 | from config.constant_map import train_model_config 8 | 9 | 10 | train_info_args = { 11 | 'devices': [0, 1, 2, 3, 4, 5, 6, 7], 12 | # 'devices': [0], 13 | 'data_backend': 'parquet', #one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 14 | 15 | # 预训练模型配置 16 | **train_model_config, 17 | 18 | 19 | 'convert_onnx': False, # 转换onnx模型 20 | 'do_train': True, 21 | # TODO change training file path 22 | 'train_file': [ './data/train.jsonl'], 23 | 'max_epochs': 5, 24 | 'max_steps': -1, 25 | 26 | # *** optimizer 27 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 28 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 29 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 30 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 31 | 32 | # *** scheduler 33 | # linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial, 34 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 35 | 36 | # 'optimizer': 'lion', 37 | # 'scheduler_type': 'CAWR', 38 | # 'scheduler':{'T_mult': 1, 39 | # 'rewarm_epoch_num': 0.5, # 如果 max_epochs is not None ! 
40 | # # 'T_0': 50000, # 如果 max_epochs is None , 设定步数 41 | # 'verbose': False}, 42 | 43 | # 'scheduler_type': 'linear',# one of [linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau 44 | # 'scheduler': None, 45 | 46 | # 切换scheduler类型 47 | # 'scheduler_type': 'WarmupCosine', 48 | # 'scheduler': None, 49 | 50 | # 'scheduler_type': 'ReduceLROnPlateau', 51 | # 'scheduler': None, 52 | 53 | # 'scheduler_type': 'Step', 54 | # 'scheduler':{ 'decay_rate': 0.999,'decay_steps': 100,'verbose': True}, 55 | 56 | # 'scheduler_type': 'CAWR', 57 | # 'scheduler':{'T_mult': 1, 'rewarm_epoch_num': 2, 'verbose': True}, 58 | 59 | # 'scheduler_type': 'CAL', 60 | # 'scheduler': {'rewarm_epoch_num': 2,'verbose': True}, 61 | 62 | 63 | 'optimizer_betas': (0.9, 0.999), 64 | 'train_batch_size': 4, 65 | 'eval_batch_size': 2, 66 | 'test_batch_size': 2, 67 | 'learning_rate': 0, # 68 | 'adam_epsilon': 1e-8, 69 | 'gradient_accumulation_steps': 1, 70 | 'max_grad_norm': 1.0, 71 | 'weight_decay': 0, 72 | 'warmup_steps': 0, 73 | 'output_dir': './output', 74 | 'max_seq_length': 512, # 75 | 'max_target_length': 100, # 预测最大长度, 保留字段 76 | 'use_fast_tokenizer': False, 77 | #'do_lower_case': False, 78 | "dataloader_drop_last": True, 79 | "dataloader_pin_memory":True, 80 | "dataloader_num_workers": 0, 81 | } 82 | 83 | 84 | 85 | 86 | 87 | 88 | train_info_args_hf = { 89 | 'data_backend': 'parquet', #one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 90 | # 预训练模型配置 91 | **train_model_config, 92 | 93 | "output_dir": "./outputs_hf", 94 | "overwrite_output_dir": True, 95 | "num_train_epochs": 20, 96 | "max_steps": -1, 97 | "save_safetensors": False, 98 | "save_strategy": "steps", 99 | "save_steps": 1000, 100 | "save_total_limit": 10, 101 | "seed": 66, 102 | "fp16": True, 103 | 'do_train': True, 104 | 'train_file': [ '/cjl/llm_finetuning/data/prompt_engineer/en/train.jsonl' ], 105 | 'do_eval': False, 106 | 'do_predict': False, 107 | "per_device_train_batch_size": 2, 108 | "per_device_eval_batch_size": 2, 109 | "gradient_accumulation_steps": 1, 110 | "evaluation_strategy": "no", 111 | "eval_steps": 100, 112 | "optim": "adamw_torch", 113 | "lr_scheduler_type": "cosine", # one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 114 | "torch_compile": False, 115 | "learning_rate": 2e-5, 116 | "adam_beta1": 0.9, 117 | "adam_beta2": 0.999, 118 | "adam_epsilon": 1e-8, 119 | "max_grad_norm": 1.0, 120 | "weight_decay": 0., 121 | "warmup_ratio": 0.03, 122 | "logging_strategy": "steps", 123 | "logging_steps": 10, 124 | "tf32": False, 125 | "gradient_checkpointing": True, 126 | 'max_seq_length': 512, # 127 | 'max_target_length': 100, # 预测最大长度, 保留字段 128 | 'use_fast_tokenizer': False, 129 | # 'do_lower_case': False, 130 | "dataloader_drop_last": True, 131 | "dataloader_pin_memory": True, 132 | "dataloader_num_workers": 0, 133 | 134 | "log_level": "info", # 'info', 'warning', 'error' and 'critical , passive', 135 | 136 | 137 | } 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /src/training/config/sft_config_lora.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/5/16 10:13 3 | 4 | import json 5 | import os 6 | import torch 7 | from config.constant_map import train_model_config,train_target_modules_maps 8 | 9 | 10 | # 默认禁用lora 相关模块 , lora 和 adalora 只能同时启用一个 11 | lora_info_args = { 12 | 'with_lora': True, # 是否启用lora模块 13 | 'lora_type': 
'lora', 14 | 'r': 8, 15 | 'target_modules': train_target_modules_maps[train_model_config['model_type']], 16 | 'lora_alpha': 32, 17 | 'lora_dropout': 0.1, 18 | 'fan_in_fan_out': False, 19 | 'bias': 'none', # Bias type for Lora. Can be 'none', 'all' or 'lora_only'" 20 | 'modules_to_save' : None, # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " 21 | 'layers_to_transform': None, 22 | 'layers_pattern': None, 23 | } 24 | 25 | adalora_info_args = { 26 | 'with_lora': False, # 是否启用adalora模块 27 | 'lora_type': 'adalora', 28 | 'r': 8, 29 | 'target_modules': train_target_modules_maps[train_model_config['model_type']], 30 | 'lora_alpha': 32, 31 | 'lora_dropout': 0.1, 32 | 'fan_in_fan_out': False, 33 | 'bias': 'none', # Bias type for Lora. Can be 'none', 'all' or 'lora_only'" 34 | 'modules_to_save' : None, # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " 35 | 'layers_to_transform': None, 36 | 'layers_pattern': None, 37 | 38 | 'target_r':8, # Target Lora matrix dimension. 39 | 'init_r': 12, #Intial Lora matrix dimension. 40 | 'tinit': 0, #The steps of initial warmup. 41 | 'tfinal': 0, #The steps of final warmup. 42 | 'deltaT': 1, #Step interval of rank allocation. 43 | 'beta1': 0.85, #Hyperparameter of EMA. 44 | 'beta2': 0.85, #Hyperparameter of EMA. 45 | 'orth_reg_weight': 0.5, #The orthogonal regularization coefficient. 46 | 'total_step': None, #The total training steps. 47 | 'rank_pattern': None, #The saved rank pattern. 48 | } 49 | 50 | 51 | train_info_args = { 52 | 'devices': 1, 53 | 'data_backend': 'parquet', #one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 54 | # 预训练模型配置 55 | **train_model_config, 56 | 'convert_onnx': False, # 转换onnx模型 57 | 'do_train': True, 58 | 'train_file': [ './data/finetune_train_examples.json'], 59 | 'max_epochs': 20, 60 | 'max_steps': -1, 61 | 62 | # *** optimizer 63 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 64 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 65 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 66 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 67 | 68 | # *** scheduler 69 | # linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial, 70 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 71 | 'optimizer': 'lion', 72 | 'scheduler_type': 'CAWR', 73 | 'scheduler':{'T_mult': 1, 74 | 'rewarm_epoch_num': 0.5, # 如果 max_epochs is not None ! 
75 | # 'T_0': 50000, # 如果 max_epochs is None , 设定步数 76 | 'verbose': False}, 77 | 78 | # 'scheduler_type': 'linear',# one of [linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau 79 | # 'scheduler': None, 80 | 81 | # 切换scheduler类型 82 | # 'scheduler_type': 'WarmupCosine', 83 | # 'scheduler': None, 84 | 85 | # 'scheduler_type': 'ReduceLROnPlateau', 86 | # 'scheduler': None, 87 | 88 | # 'scheduler_type': 'Step', 89 | # 'scheduler':{ 'decay_rate': 0.999,'decay_steps': 100,'verbose': True}, 90 | 91 | # 'scheduler_type': 'CAWR', 92 | # 'scheduler':{'T_mult': 1, 'rewarm_epoch_num': 2, 'verbose': True}, 93 | 94 | # 'scheduler_type': 'CAL', 95 | # 'scheduler': {'rewarm_epoch_num': 2,'verbose': True}, 96 | 97 | 98 | 'optimizer_betas': (0.9, 0.999), 99 | 'train_batch_size': 2, 100 | 'eval_batch_size': 2, 101 | 'test_batch_size': 2, 102 | 'learning_rate': 2e-4, # 103 | 'adam_epsilon': 1e-8, 104 | 'gradient_accumulation_steps': 1, 105 | 'max_grad_norm': 1.0, 106 | 'weight_decay': 0, 107 | 'warmup_steps': 0, 108 | 'output_dir': './output', 109 | 'max_seq_length': 512, # 110 | 'max_target_length': 100, # 预测最大长度, 保留字段 111 | 'use_fast_tokenizer': False, 112 | #'do_lower_case': False, 113 | 114 | ############## lora模块 115 | 'lora': lora_info_args, 116 | 'adalora': adalora_info_args, 117 | "dataloader_drop_last": True, 118 | "dataloader_pin_memory": True, 119 | "dataloader_num_workers": 0, 120 | 121 | } 122 | 123 | 124 | 125 | 126 | train_info_args_hf = { 127 | 'data_backend': 'parquet', 128 | # one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 129 | # 预训练模型配置 130 | **train_model_config, 131 | 132 | "output_dir": "./outputs_hf", 133 | "overwrite_output_dir": True, 134 | "num_train_epochs": 20, 135 | "max_steps": -1, 136 | "save_safetensors": False, 137 | "save_strategy": "steps", 138 | "save_steps": 1000, 139 | "save_total_limit": 10, 140 | "seed": 42, 141 | "fp16": True, 142 | 'do_train': True, 143 | 'train_file': [ './data/finetune_train_examples.json'], 144 | 'do_eval': False, 145 | 'do_predict': False, 146 | "per_device_train_batch_size": 2, 147 | "per_device_eval_batch_size": 2, 148 | "gradient_accumulation_steps": 1, 149 | "evaluation_strategy": "no", 150 | "eval_steps": 100, 151 | "optim": "adamw_torch", 152 | "lr_scheduler_type": "cosine", # one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 153 | "torch_compile": False, 154 | "learning_rate": 2e-5, 155 | "adam_beta1": 0.9, 156 | "adam_beta2": 0.999, 157 | "adam_epsilon": 1e-8, 158 | "max_grad_norm": 1.0, 159 | "weight_decay": 0., 160 | "warmup_ratio": 0.03, 161 | "logging_strategy": "steps", 162 | "logging_steps": 10, 163 | "tf32": False, 164 | "gradient_checkpointing": True, 165 | 'max_seq_length': 512, # 166 | 'max_target_length': 100, # 预测最大长度, 保留字段 167 | 'use_fast_tokenizer': False, 168 | # 'do_lower_case': False, 169 | "dataloader_drop_last": True, 170 | "dataloader_pin_memory": True, 171 | "dataloader_num_workers": 0, 172 | 173 | "log_level": "info", # 'info', 'warning', 'error' and 'critical , passive', 174 | ############## lora模块 175 | 'lora': lora_info_args, 176 | 'adalora': adalora_info_args, 177 | 178 | } 179 | 180 | 181 | -------------------------------------------------------------------------------- /src/training/config/sft_config_ptv2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/5/16 10:13 3 | 4 | import json 5 | import os 6 | 7 | from config.constant_map 
import train_model_config 8 | 9 | 10 | 11 | prompt_info_args = { 12 | "with_prompt": True, 13 | "prompt_type": "prefix_tuning", # one of prompt_tuning,p_tuning,prefix_tuning,adaption_prompt 14 | "task_type": "causal_lm", # one of seq_cls,seq_2_seq_lm,causal_lm,token_cls 15 | "prefix_projection": False, # Whether to project the prefix tokens" 16 | "num_virtual_tokens": 32, # Number of virtual tokens 17 | # "token_dim": 2048, # The hidden embedding dimension of the base transformer model. 18 | # "num_transformer_submodules": 1, # The number of transformer submodules in the base transformer model. 19 | # "num_attention_heads" : 24, # The number of attention heads in the base transformer model. 20 | # "num_layers": 1, # The number of layers in the base transformer model. 21 | # "encoder_hidden_size": 2048, # The hidden size of the encoder 22 | } 23 | 24 | train_info_args = { 25 | 'devices': 1, 26 | 'data_backend': 'parquet', #one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 27 | # 预训练模型配置 28 | **train_model_config, 29 | 30 | 'convert_onnx': False, # 转换onnx模型 31 | 'do_train': True, 32 | 'train_file': [ './data/finetune_train_examples.json'], 33 | 'max_epochs': 20, 34 | 'max_steps': -1, 35 | 36 | # *** optimizer 37 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 38 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 39 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 40 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 41 | 42 | # *** scheduler 43 | # linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial, 44 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 45 | 46 | 'optimizer': 'lion', 47 | 'scheduler_type': 'CAWR', 48 | 'scheduler':{'T_mult': 1,'rewarm_epoch_num': 0.5, 49 | # 如果 max_epochs is not None ! 
50 | # 'T_0': 50000, # 如果 max_epochs is None , 设定步数 51 | 'verbose': False}, 52 | 53 | # 'scheduler_type': 'linear',# one of [linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau 54 | # 'scheduler': None, 55 | 56 | # 切换scheduler类型 57 | # 'scheduler_type': 'WarmupCosine', 58 | # 'scheduler': None, 59 | 60 | # 'scheduler_type': 'ReduceLROnPlateau', 61 | # 'scheduler': None, 62 | 63 | # 'scheduler_type': 'Step', 64 | # 'scheduler':{ 'decay_rate': 0.999,'decay_steps': 100,'verbose': True}, 65 | 66 | # 'scheduler_type': 'CAWR', 67 | # 'scheduler':{'T_mult': 1, 'rewarm_epoch_num': 2, 'verbose': True}, 68 | 69 | # 'scheduler_type': 'CAL', 70 | # 'scheduler': {'rewarm_epoch_num': 2,'verbose': True}, 71 | 72 | 73 | 'optimizer_betas': (0.9, 0.999), 74 | 'train_batch_size': 2, 75 | 'eval_batch_size': 2, 76 | 'test_batch_size': 2, 77 | 'learning_rate': 5e-4, # 78 | 'adam_epsilon': 1e-8, 79 | 'gradient_accumulation_steps': 1, 80 | 'max_grad_norm': 1.0, 81 | 'weight_decay': 0, 82 | 'warmup_steps': 0, 83 | 'output_dir': './output', 84 | 'max_seq_length': 512, # 85 | 'max_target_length': 100, # 预测最大长度, 保留字段 86 | 'use_fast_tokenizer': False, 87 | #'do_lower_case': False, 88 | "dataloader_drop_last": True, 89 | "dataloader_pin_memory": True, 90 | "dataloader_num_workers": 0, 91 | 92 | ############## lora模块 93 | 'prompt': prompt_info_args, 94 | 95 | } 96 | 97 | 98 | 99 | 100 | 101 | 102 | train_info_args_hf = { 103 | 'data_backend': 'parquet', 104 | # one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 105 | # 预训练模型配置 106 | **train_model_config, 107 | 108 | "output_dir": "./outputs_hf", 109 | "overwrite_output_dir": True, 110 | "num_train_epochs": 20, 111 | "max_steps": -1, 112 | "save_safetensors": False, 113 | "save_strategy": "steps", 114 | "save_steps": 1000, 115 | "save_total_limit": 10, 116 | "seed": 42, 117 | "fp16": True, 118 | 'do_train': True, 119 | 'train_file': [ './data/finetune_train_examples.json'], 120 | 'do_eval': False, 121 | 'do_predict': False, 122 | "per_device_train_batch_size": 2, 123 | "per_device_eval_batch_size": 2, 124 | "gradient_accumulation_steps": 1, 125 | "evaluation_strategy": "no", 126 | "eval_steps": 100, 127 | "optim": "adamw_torch", 128 | "lr_scheduler_type": "cosine",# one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 129 | "torch_compile": False, 130 | "learning_rate": 2e-5, 131 | "adam_beta1": 0.9, 132 | "adam_beta2": 0.999, 133 | "adam_epsilon": 1e-8, 134 | "max_grad_norm": 1.0, 135 | "weight_decay": 0., 136 | "warmup_ratio": 0.03, 137 | "logging_strategy": "steps", 138 | "logging_steps": 10, 139 | "tf32": False, 140 | "gradient_checkpointing": True, 141 | 'max_seq_length': 512, # 142 | 'max_target_length': 100, # 预测最大长度, 保留字段 143 | 'use_fast_tokenizer': False, 144 | # 'do_lower_case': False, 145 | "dataloader_drop_last": True, 146 | "dataloader_pin_memory": True, 147 | "dataloader_num_workers": 0, 148 | 149 | "log_level": "info", # 'info', 'warning', 'error' and 'critical , passive', 150 | ############## lora模块 151 | 'prompt': prompt_info_args, 152 | } 153 | 154 | -------------------------------------------------------------------------------- /src/training/data/finetune_train_examples.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 0, "paragraph": [{"q": "Make me a cup of tea.", "a": "Please provide me with instructions on how to make a cup of tea."}]} 2 | {"id": 1, "paragraph": [{"q": "Give me 5 first date ideas", "a": 
"Provide 5 first date ideas with reasons for each suggestion."}]} 3 | -------------------------------------------------------------------------------- /src/training/data_processer.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/3/25 18:36 2 | # @Author : tk 3 | import copy 4 | from enum import Enum 5 | import numpy as np 6 | from transformers import PreTrainedTokenizer 7 | 8 | DEFAULT_PAD_TOKEN = "[PAD]" 9 | DEFAULT_EOS_TOKEN = "" 10 | DEFAULT_BOS_TOKEN = "" 11 | DEFAULT_UNK_TOKEN = "" 12 | 13 | class DataStrategy(Enum): 14 | tunction = 1 15 | slidding = 2 16 | 17 | 18 | 19 | def build_template_llama(query, answer = None,prefix=None, history=None): 20 | return query 21 | 22 | 23 | def build_template_default(query, answer = None,prefix=None, history=None): 24 | prompt = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]".format(query) 25 | return prompt 26 | 27 | def build_template_tiger(query,answer = None,prefix=None, history=None): 28 | prompt = prefix or '' 29 | tok_ins = "\n\n### Instruction:\n" 30 | tok_res = "\n\n### Response:\n" 31 | if history is not None: 32 | for q,a in history: 33 | prompt += "{}{}{}{}".format(tok_ins,q,tok_res,a) 34 | 35 | prompt += "{}{}{}".format(tok_ins, query, tok_res) 36 | if answer is not None: 37 | prompt += answer 38 | return prompt 39 | 40 | 41 | #切换模板 42 | build_template = build_template_default 43 | # build_template = build_template_llama 44 | 45 | 46 | class TokenIdsMaker: 47 | @classmethod 48 | def final(cls, tokenizer, input_ids, labels, max_seq_length): 49 | seqlen = np.asarray(len(input_ids), dtype=np.int32) 50 | pad_len = max_seq_length - seqlen 51 | input_ids = np.asarray(input_ids, dtype=np.int32) 52 | attention_mask = np.asarray([1] * len(input_ids), dtype=np.int32) 53 | labels = np.asarray(labels, dtype=np.int32) 54 | if pad_len: 55 | pad_val = tokenizer.eos_token_id 56 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 57 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0)) 58 | labels = np.pad(labels, (0, pad_len), 'constant', constant_values=(-100, -100)) 59 | d = { 60 | 'input_ids': input_ids, 61 | 'attention_mask': attention_mask, 62 | 'labels': labels, 63 | 'seqlen': seqlen 64 | } 65 | return d 66 | @classmethod 67 | def tunction(cls, tokenizer: PreTrainedTokenizer, config, sup, max_seq_length, examples): 68 | sptoken = [config.bos_token_id] 69 | ds = [] 70 | prefix, examples = examples 71 | max_a_b_len = 0 72 | for sid, (q, a) in enumerate(examples): 73 | a_ids = tokenizer.encode(text=build_template(q,prefix=prefix,history=examples[:sid]), add_special_tokens=False) 74 | # from IPython import embed 75 | # embed() 76 | b_ids = tokenizer.encode(text=a, add_special_tokens=False) 77 | max_a_b_len = max(max_a_b_len, len(a_ids) + len(b_ids) + len(sptoken) + 1) 78 | while len(a_ids) + len(b_ids) > max_seq_length - len(sptoken) - 1: 79 | if len(b_ids) > len(a_ids): 80 | b_ids.pop(-1) 81 | else: 82 | a_ids.pop(0) 83 | b_ids += [config.eos_token_id] 84 | input_ids = a_ids + b_ids 85 | labels = copy.deepcopy(input_ids) if not sup else [-100] * len(a_ids) + copy.deepcopy(b_ids) 86 | input_ids = sptoken + input_ids 87 | labels = sptoken + labels if not sup else [-100] * len(sptoken) + labels 88 | assert len(input_ids) <= max_seq_length 89 | ds.append(cls.final(tokenizer, input_ids, labels, max_seq_length)) 90 | return ds, 
max_a_b_len 91 | 92 | 93 | @classmethod 94 | def slidding(cls, tokenizer: PreTrainedTokenizer,config,stride,max_seq_length, examples, 95 | sliding_size=None, 96 | src_max_length=-1, 97 | dst_max_length=-1, 98 | sup=True): 99 | sptoken = [config.bos_token_id] 100 | if sliding_size is None or sliding_size < 0: 101 | sliding_size = max_seq_length - len(sptoken) 102 | 103 | assert sliding_size <= max_seq_length - len(sptoken) 104 | 105 | ds = [] 106 | prefix, examples = examples 107 | for sid, (q, a) in enumerate(examples): 108 | a_ids = tokenizer.encode(text=build_template(q, prefix=prefix, history=examples[:sid]),add_special_tokens=False) 109 | b_ids = tokenizer.encode(text=a, add_special_tokens=False) 110 | if src_max_length and src_max_length > 0: 111 | a_ids = a_ids[:src_max_length] 112 | if dst_max_length and dst_max_length > 0: 113 | b_ids = b_ids[:dst_max_length] 114 | 115 | b_ids += [config.eos_token_id] 116 | input_ids_qa = a_ids + b_ids 117 | labels_all = copy.deepcopy(input_ids_qa) if not sup else [-100] * len(a_ids) + b_ids 118 | 119 | pos = 0 120 | while pos < len(input_ids_qa): 121 | input_ids = input_ids_qa[pos:pos + max_seq_length - len(sptoken)] 122 | labels = labels_all[pos:pos + max_seq_length - len(sptoken)] 123 | 124 | pos += sliding_size 125 | if np.all(np.asarray(labels) == -100): 126 | continue 127 | 128 | input_ids = sptoken + input_ids 129 | labels = sptoken + labels if not sup else [-100] * len(sptoken) + labels 130 | ds.append(cls.final(tokenizer, input_ids, labels, max_seq_length)) 131 | return ds 132 | 133 | 134 | -------------------------------------------------------------------------------- /src/training/data_utils.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/1/22 16:22 2 | # @Author : tk 3 | # @FileName: data_utils.py 4 | 5 | import copy 6 | import json 7 | import os 8 | import random 9 | import typing 10 | import numpy as np 11 | import torch 12 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments,TrainingArgumentsHF, DataArguments 13 | from aigc_zoo.model_zoo.llm.llm_model import PetlArguments,LoraConfig,PromptArguments 14 | from fastdatasets.record import load_dataset as Loader, RECORD, WriterObject, gfile 15 | from transformers import PreTrainedTokenizer, HfArgumentParser, PretrainedConfig 16 | from data_processer import DataStrategy, TokenIdsMaker, build_template, DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, \ 17 | DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN 18 | from config import * 19 | from module_setup import module_setup 20 | 21 | 22 | module_setup() 23 | 24 | data_conf = { 25 | 'strategy': DataStrategy.tunction, # 数据策略选项 26 | DataStrategy.tunction: { 27 | 'sup': True, # 是否监督模式 28 | }, 29 | 30 | DataStrategy.slidding: { 31 | 'stride': int(train_info_args['max_seq_length'] / 3 * 2), 32 | 'sup': True, # 是否监督模式 33 | "src_max_length": train_info_args['max_seq_length'] - 10, 34 | "dst_max_length": None, 35 | } 36 | 37 | } 38 | 39 | 40 | 41 | def preprocess(text): 42 | return text 43 | 44 | def postprocess(text): 45 | return text 46 | 47 | 48 | class NN_DataHelper(DataHelper): 49 | index = 1 50 | data_len = [] 51 | 52 | def __init__(self, *args, **kwargs): 53 | super(NN_DataHelper, self).__init__(*args, **kwargs) 54 | assert data_conf[DataStrategy.slidding]['stride'] > 0 55 | 56 | def load_tokenizer_and_config(self, *args, **kwargs): 57 | ret = super().load_tokenizer_and_config(*args, **kwargs) 58 | self._preprocess_tokenizer_config() 59 | return ret 60 | 61 | def 
_preprocess_tokenizer_config(self): 62 | model_args = self.model_args 63 | tokenizer = self.tokenizer 64 | config = self.config 65 | 66 | 67 | 68 | if "llama" in model_args.model_type.lower(): 69 | special_tokens_dict = dict() 70 | # from IPython import embed 71 | # embed() 72 | # if tokenizer.pad_token is None: 73 | # special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 74 | if tokenizer.eos_token is None: 75 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 76 | if tokenizer.bos_token is None: 77 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 78 | if tokenizer.unk_token is None: 79 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 80 | 81 | _ = tokenizer.add_special_tokens(special_tokens_dict) 82 | 83 | # tokenizer.add_special_tokens({ 84 | # "eos_token": DEFAULT_EOS_TOKEN, 85 | # "bos_token": DEFAULT_BOS_TOKEN, 86 | # "unk_token": DEFAULT_UNK_TOKEN, 87 | # }) 88 | # if tokenizer.pad_token_id is None or tokenizer.pad_token_id == -1: 89 | # tokenizer.pad_token_id = tokenizer.eos_token_id 90 | 91 | if tokenizer.pad_token is None: 92 | tokenizer.add_special_tokens({ 93 | "pad_token": tokenizer.eos_token, 94 | }) 95 | if config.pad_token_id is None or config.pad_token_id == -1: 96 | config.pad_token_id = tokenizer.eos_token_id 97 | 98 | 99 | 100 | if config.decoder_start_token_id is None: 101 | config.decoder_start_token_id = config.bos_token_id 102 | 103 | if config.decoder_start_token_id != tokenizer.bos_token_id: 104 | print('*' * 30, 'config.decoder_start_token_id != tokenizer.bos_token_id !!!') 105 | 106 | assert config.decoder_start_token_id == config.bos_token_id 107 | 108 | def on_data_ready(self): 109 | self.index = -1 110 | 111 | # 切分词 112 | def on_data_process(self, data: typing.Any, mode: str): 113 | self.index += 1 114 | 115 | tokenizer: PreTrainedTokenizer 116 | config = self.config 117 | max_seq_length = self.max_seq_length_dict[mode] 118 | tokenizer = self.tokenizer 119 | 120 | examples = data 121 | # from IPython import embed 122 | # embed() 123 | # exit() 124 | 125 | strategy = data_conf['strategy'] 126 | if strategy == DataStrategy.tunction: 127 | ds,l = TokenIdsMaker.tunction(tokenizer, config=config, max_seq_length=max_seq_length, examples=examples, 128 | **data_conf[strategy]) 129 | self.data_len.append(l) 130 | elif strategy == DataStrategy.slidding: 131 | ds = TokenIdsMaker.slidding(tokenizer, config=config, max_seq_length=max_seq_length, examples=examples, 132 | **data_conf[strategy]) 133 | 134 | else: 135 | raise ValueError('Invalid strategy', strategy) 136 | if not ds: 137 | return None 138 | 139 | if self.index < 3: 140 | print(ds[0]) 141 | # from IPython import embed 142 | # embed() 143 | # exit() 144 | return ds 145 | 146 | def _get_paragraph(self,lines): 147 | D = [] 148 | for line_id, line in enumerate(lines): 149 | jd = json.loads(line) 150 | if not jd: 151 | continue 152 | paragraph = jd['paragraph'] 153 | if line_id < 10: 154 | print(paragraph) 155 | 156 | prefix = jd.get('p', '') 157 | paragraph = [(preprocess(session['q']), 158 | preprocess('\n'.join(session['a'])) if isinstance(session['a'], list) else preprocess( 159 | session['a'])) 160 | for session in paragraph] 161 | sub = [] 162 | # 自行做模板 163 | # TODO: make a template for llama2 164 | # https://gpus.llm-utils.org/llama-2-prompt-template/ 165 | for (q,a) in paragraph: 166 | if not len(a): 167 | continue 168 | assert len(a), ValueError('answer cannot empty') 169 | sub.append((q, a)) 170 | D.append((prefix, copy.deepcopy(sub))) 171 | # from IPython import embed 172 | # embed() 173 | 
# exit() 174 | 175 | sub.clear() 176 | return D 177 | 178 | def _get_messages(self,lines): 179 | D = [] 180 | for line_id, line in enumerate(lines): 181 | jd = json.loads(line) 182 | if not jd: 183 | continue 184 | conversations = jd['conversations'] 185 | if line_id < 10: 186 | print(conversations) 187 | 188 | paragraph = [] 189 | prefix = '' 190 | pair = [None,None] 191 | for m in conversations: 192 | if m["from"] == 'user': 193 | pair[0] = preprocess(m["value"]) 194 | elif m["from"] == 'assistant': 195 | pair[1] = preprocess(m["value"]) 196 | elif m["from"] == 'system': 197 | prefix = preprocess(m["value"]) 198 | if pair[0] is not None and pair[1] is not None: 199 | paragraph.append(tuple(pair)) 200 | pair[0],pair[1] = None,None 201 | 202 | sub = [] 203 | # 自行做模板 204 | for (q, a) in paragraph: 205 | assert len(a), ValueError('answer cannot empty') 206 | sub.append((q, a)) 207 | D.append((prefix, copy.deepcopy(sub))) 208 | sub.clear() 209 | return D 210 | # 读取文件 211 | def on_get_corpus(self, files: typing.List, mode: str): 212 | D = [] 213 | for file in files: 214 | with open(file, mode='r', encoding='utf-8', newline='\n') as f: 215 | lines = f.readlines() 216 | is_new = False 217 | if len(lines) > 0: 218 | is_new = 'conversations' in json.loads(lines[0]) 219 | if is_new: 220 | D.extend(self._get_messages(lines)) 221 | else: 222 | D.extend(self._get_paragraph(lines)) 223 | return D 224 | 225 | def collate_fn(self, batch): 226 | o = {} 227 | for i, b in enumerate(batch): 228 | if i == 0: 229 | for k in b: 230 | o[k] = [torch.tensor(b[k])] 231 | else: 232 | for k in b: 233 | o[k].append(torch.tensor(b[k])) 234 | for k in o: 235 | o[k] = torch.stack(o[k]) 236 | 237 | maxlen = torch.max(o.pop('seqlen')) 238 | o['input_ids'] = o['input_ids'][:, :maxlen] 239 | o['attention_mask'] = o['attention_mask'][:, :maxlen] 240 | o['labels'] = o['labels'][:, :maxlen].long() 241 | return o 242 | 243 | def make_dataset_all(self): 244 | data_args = self.data_args 245 | # schema for arrow parquet 246 | schema = { 247 | "input_ids": "int32_list", 248 | "attention_mask": "int32_list", 249 | "labels": "int32_list", 250 | "seqlen": "int32_list", 251 | } 252 | # 缓存数据集 253 | if data_args.do_train: 254 | self.make_dataset_with_args(data_args.train_file, mixed_data=False, shuffle=True, mode='train', 255 | schema=schema) 256 | if data_args.do_eval: 257 | self.make_dataset_with_args(data_args.eval_file, mode='eval', schema=schema) 258 | if data_args.do_test: 259 | self.make_dataset_with_args(data_args.test_file, mode='test', schema=schema) 260 | 261 | 262 | 263 | if __name__ == '__main__': 264 | 265 | if global_args["trainer_backend"] == "hf": 266 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsHF, DataArguments, PetlArguments, PromptArguments), 267 | conflict_handler='resolve') 268 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(train_info_args, 269 | allow_extra_keys=True, ) 270 | else: 271 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, PetlArguments, PromptArguments)) 272 | model_args, training_args, data_args, _, _ = parser.parse_dict(train_info_args) 273 | 274 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 275 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs={"torch_dtype": torch.float16}) 276 | 277 | 278 | 279 | 280 | # 缓存数据集 281 | # 检测是否存在 output/dataset_0-train.record ,不存在则制作数据集 282 | dataHelper.make_dataset_all() 283 | print(np.mean(dataHelper.data_len)) 284 | 
print(np.percentile(dataHelper.data_len, [25, 50, 99])) 285 | print(np.max(dataHelper.data_len)) 286 | 287 | 288 | # def shuffle_records(record_filenames, outfile, compression_type='GZIP'): 289 | # print('shuffle_records record...') 290 | # options = RECORD.TFRecordOptions(compression_type=compression_type) 291 | # dataset_reader = Loader.RandomDataset(record_filenames, options=options, with_share_memory=True) 292 | # data_size = len(dataset_reader) 293 | # all_example = [] 294 | # for i in tqdm(range(data_size), desc='load records'): 295 | # serialized = dataset_reader[i] 296 | # all_example.append(serialized) 297 | # dataset_reader.close() 298 | # 299 | # shuffle_idx = list(range(data_size)) 300 | # random.shuffle(shuffle_idx) 301 | # writer = WriterObject(outfile, options=options) 302 | # for i in tqdm(shuffle_idx, desc='shuffle record'): 303 | # example = all_example[i] 304 | # writer.write(example) 305 | # writer.close() 306 | # 307 | # 308 | # # 对每个record 再次打乱 309 | # for filename in dataHelper.train_files: 310 | # shuffle_records(filename, filename) 311 | -------------------------------------------------------------------------------- /src/training/infer.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/4/2 22:49 2 | # @Author : tk 3 | # @FileName: infer 4 | 5 | import torch 6 | from deep_training.data_helper import ModelArguments 7 | from transformers import HfArgumentParser 8 | from data_utils import train_info_args, NN_DataHelper, get_deepspeed_config 9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer 10 | from aigc_zoo.utils.llm_generate import Generate 11 | from aigc_zoo.model_zoo.llm.llm_model import RotaryNtkScaledArguments,RotaryLinearScaledArguments # aigc-zoo 0.1.20 12 | 13 | deep_config = get_deepspeed_config() 14 | 15 | 16 | def infer_tiger(model,tokenizer,max_input_length=512): 17 | tok_ins = "\n\n### Instruction:\n" 18 | tok_res = "\n\n### Response:\n" 19 | prompt_input = tok_ins + "{instruction}" + tok_res 20 | 21 | generation_config = { 22 | "do_sample": True, 23 | "eos_token_id": 2, 24 | "max_length": max_input_length, 25 | "pad_token_id": 60514, 26 | "repetition_penalty": 1.1, 27 | "temperature": 0.3, 28 | "transformers_version": "4.31.0" 29 | } 30 | text_list = ["写一个诗歌,关于冬天", 31 | "晚上睡不着应该怎么办", 32 | "从南京到上海的路线", 33 | ] 34 | 35 | for input in text_list: 36 | sess_text = '' 37 | 38 | query_text = input.strip() 39 | sess_text += tok_ins + query_text 40 | input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]}) 41 | inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length) 42 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 43 | output = model.generate(**inputs, **generation_config) 44 | output_str = tokenizer.decode(output[0], skip_special_tokens=False, spaces_between_special_tokens=False) 45 | answer = output_str.rsplit(tok_res, 1)[1].strip() 46 | if answer.endswith(tokenizer.eos_token): 47 | answer = answer.rsplit(tokenizer.eos_token, 1)[0].strip() 48 | 49 | print('input', input) 50 | print('output', answer) 51 | 52 | if __name__ == '__main__': 53 | 54 | 55 | parser = HfArgumentParser((ModelArguments,)) 56 | (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True) 57 | 58 | dataHelper = NN_DataHelper(model_args) 59 | tokenizer, config, _,_= dataHelper.load_tokenizer_and_config() 60 | 61 | enable_ntk = False 62 | rope_args = None 63 | if enable_ntk and config.model_type == 'llama': 64 | rope_args = 
64 |         rope_args = RotaryNtkScaledArguments(name='rotary_emb',max_position_embeddings=2048, alpha=4) # extend to 8k
65 |         # rope_args = RotaryLinearScaledArguments(name='rotary_emb',max_position_embeddings=2048, scale=4) # extend to 8k
66 |
67 |
68 |     pl_model = MyTransformer(config=config, model_args=model_args,torch_dtype=config.torch_dtype,rope_args=rope_args)
69 |     model = pl_model.get_llm_model()
70 |     model = model.eval()
71 |     if hasattr(model,'quantize'):
72 |         # llama / llama2 quantization is supported
73 |         if not model.quantized:
74 |             # adjust as needed; currently only 4/8-bit quantization is supported, and the quantized model can be saved
75 |             model.half().quantize(4).cuda()
76 |             # save the quantized weights
77 |             # model.save_pretrained('llama2-7b-chat-int4',max_shard_size="2GB")
78 |             # exit(0)
79 |         else:
80 |             # already quantized
81 |             model.half().cuda()
82 |     else:
83 |         model.half().cuda()
84 |
85 |     if train_info_args['model_name_or_path'].lower().find('tiger') >=0:
86 |         infer_tiger(model,tokenizer)
87 |     else:
88 |         text_list = ["写一个诗歌,关于冬天",
89 |                      "晚上睡不着应该怎么办",
90 |                      "从南京到上海的路线",
91 |                      ]
92 |         for input in text_list:
93 |             response = Generate.generate(model, query=input, tokenizer=tokenizer, max_length=512,
94 |                                          eos_token_id=config.eos_token_id,
95 |                                          do_sample=False, top_p=0.7, temperature=0.95, )
96 |             print('input', input)
97 |             print('output', response)
--------------------------------------------------------------------------------
/src/training/infer_finetuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer
4 |
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser, AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper, get_deepspeed_config,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer
10 | from aigc_zoo.utils.llm_generate import Generate
11 | import json
12 | from tqdm import tqdm
13 |
14 | deep_config = get_deepspeed_config()
15 |
16 |
17 | if __name__ == '__main__':
18 |     # TODO add input file and output file (requires json)
19 |     # or you could implement it yourself
20 |     input_file = '../../data/data_construction/examples_ctx.json'
21 |     output_file = '../../data/data_construction/examples_ctx_optimized_gen.json'
22 |
23 |     # optimized on evaluation set
24 |     # input_file = '../../data/testset/dolly_eval.json'
25 |     # output_file = '../../data/testset/dolly_eval_optimized.json'
26 |
27 |     parser = HfArgumentParser((ModelArguments,))
28 |     (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True)
29 |
30 |     dataHelper = NN_DataHelper(model_args)
31 |     tokenizer, _, _,_= dataHelper.load_tokenizer_and_config()
32 |
33 |
34 |     config = AutoConfig.from_pretrained('./output/best_ckpt')
35 |     pl_model = MyTransformer(config=config, model_args=model_args,torch_dtype=config.torch_dtype,)
36 |
37 |     # commands for converting deepspeed weights with the conversion script
38 |     # usually sort by time and pick the latest checkpoint folder
39 |     # cd best_ckpt/last
40 |     # python zero_to_fp32.py . ../last.ckpt
41 |
42 |     train_weight = './output/best_ckpt'
43 |
44 |     pl_model.load_sft_weight(train_weight,strict=True)
45 |
46 |     # save HF weights
47 |     # config.save_pretrained('convert/')
48 |
49 |     # save sft p-tuning-v2 weights
50 |     # pl_model.save_sft_weight('convert/pytorch_model_sft_ptv2.bin')
51 |
52 |     # save sft weights
53 |     # pl_model.save_sft_weight('convert/pytorch_model_sft.bin')
54 |
55 |     model = pl_model.get_llm_model()
56 |
57 |     model.eval().half().cuda()
58 |
59 |
60 |     with open(input_file, encoding='utf-8') as f:
61 |         text_list = json.load(f)[:]
62 |
63 |     gen_res = []
64 |
65 |     for input in tqdm(text_list[:]):
66 |
67 |         response = Generate.generate(model, query=build_template((input['instruction']+'\n'+input['context']).strip()), tokenizer=tokenizer, max_new_tokens=1024,
68 |                                      eos_token_id=config.eos_token_id,
69 |                                      do_sample=True, top_p=0.9, temperature=0.6, num_beams=1)
70 |
71 |         input['gen_res'] = response.strip()
72 |         gen_res.append(input)
73 |
74 |     with open(output_file, 'w', encoding='utf-8') as f:
75 |         json.dump(gen_res, f, indent=4, ensure_ascii=False)
76 |
--------------------------------------------------------------------------------
/src/training/infer_lora_finetuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer_lora_finetuning
4 | import os
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser,AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper,global_args,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PromptArguments
10 | from aigc_zoo.utils.llm_generate import Generate
11 |
12 |
13 | if __name__ == '__main__':
14 |     train_info_args['seed'] = None
15 |     parser = HfArgumentParser((ModelArguments,))
16 |     (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True)
17 |
18 |
19 |     dataHelper = NN_DataHelper(model_args)
20 |     tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config()
21 |
22 |
23 |     # usually sort by time and pick the latest checkpoint folder
24 |     ckpt_dir = './best_ckpt/last'
25 |
26 |     config = AutoConfig.from_pretrained(ckpt_dir)
27 |     lora_args = PetlArguments.from_pretrained(ckpt_dir)
28 |
29 |     assert lora_args.inference_mode == True
30 |
31 |     new_num_tokens = config.vocab_size
32 |     if config.task_specific_params is not None and config.task_specific_params.get('vocab_size',None) is not None:
33 |         config.vocab_size = config.task_specific_params['vocab_size']
34 |
35 |     pl_model = MyTransformer(config=config, model_args=model_args,
36 |                              lora_args=lora_args,
37 |                              torch_dtype=config.torch_dtype,
38 |                              new_num_tokens=new_num_tokens,
39 |                              # load_in_8bit=global_args["load_in_8bit"],
40 |                              # # device_map="auto",
41 |                              # device_map = {"":0} # the first GPU
42 |                              )
43 |
44 |     # load the lora weights
45 |     pl_model.load_sft_weight(ckpt_dir)
46 |
47 |     pl_model.eval().half().cuda()
48 |
49 |     enable_merge_weight = False
50 |
51 |     if enable_merge_weight:
52 |         # merge the lora weights and save
53 |         pl_model.save_sft_weight(os.path.join(ckpt_dir, 'pytorch_model_merge.bin'), merge_lora_weight=True)
54 |     else:
55 |         model = pl_model.get_llm_model()
56 |
57 |         text_list = ["写一个诗歌,关于冬天",
58 |                      "晚上睡不着应该怎么办",
59 |                      "从南京到上海的路线",
60 |                      ]
61 |         for input in text_list:
62 |             response = Generate.generate(model, query=build_template(input), tokenizer=tokenizer, max_length=512,
63 |                                          eos_token_id=config.eos_token_id,
64 |                                          do_sample=False, top_p=0.7, temperature=0.95, )
65 |             print('input', input)
66 |             print('output', response)
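# A minimal sketch of how the merged checkpoint written by the `enable_merge_weight` branch
# above might be reloaded for plain inference. Illustrative only: it assumes `load_sft_weight`
# also accepts a single weight-file path (as in the commented example in train.py) and that the
# merged file contains full model weights, so no lora_args are needed when rebuilding the model.
#
# pl_model = MyTransformer(config=config, model_args=model_args, torch_dtype=config.torch_dtype)
# pl_model.load_sft_weight(os.path.join(ckpt_dir, 'pytorch_model_merge.bin'))
# model = pl_model.get_llm_model().eval().half().cuda()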
--------------------------------------------------------------------------------
/src/training/infer_muti_lora_finetuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer_muti_lora_finetuning
4 | import os
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser,AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper,global_args,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,\
10 |     PetlArguments,PromptArguments,PetlModel
11 | from aigc_zoo.utils.llm_generate import Generate
12 |
13 |
14 | if __name__ == '__main__':
15 |     train_info_args['seed'] = None
16 |     parser = HfArgumentParser((ModelArguments,))
17 |     (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True)
18 |
19 |
20 |     dataHelper = NN_DataHelper(model_args)
21 |     tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config()
22 |
23 |
24 |     # usually sort by time and pick the latest checkpoint folder
25 |     ckpt_dir = './best_ckpt/last'
26 |
27 |     config = AutoConfig.from_pretrained(ckpt_dir)
28 |     lora_args = PetlArguments.from_pretrained(ckpt_dir)
29 |
30 |     assert lora_args.inference_mode == True
31 |
32 |     new_num_tokens = config.vocab_size
33 |     if config.task_specific_params is not None and config.task_specific_params.get('vocab_size',None) is not None:
34 |         config.vocab_size = config.task_specific_params['vocab_size']
35 |
36 |     pl_model = MyTransformer(config=config, model_args=model_args,
37 |                              lora_args=lora_args,
38 |                              torch_dtype=config.torch_dtype,
39 |                              new_num_tokens=new_num_tokens,
40 |                              # load_in_8bit=global_args["load_in_8bit"],
41 |                              # # device_map="auto",
42 |                              # device_map = {"":0} # the first GPU
43 |                              )
44 |
45 |     # load multiple lora weights
46 |     pl_model.load_sft_weight(ckpt_dir, adapter_name="default")
47 |
48 |     # load multiple lora weights
49 |     # pl_model.load_sft_weight(ckpt_dir, adapter_name="yourname")
50 |
51 |     # load multiple lora weights
52 |     # pl_model.load_sft_weight(ckpt_dir, adapter_name="yourname")
53 |
54 |     pl_model.eval().half().cuda()
55 |
56 |     # the backbone model is replaced by a PetlModel
57 |     lora_model: PetlModel = pl_model.backbone
58 |
59 |     text_list = ["写一个诗歌,关于冬天",
60 |                  "晚上睡不着应该怎么办",
61 |                  "从南京到上海的路线",
62 |                  ]
63 |
64 |     # inference with the base model (adapters disabled)
65 |     with lora_model.disable_adapter():
66 |         for input in text_list:
67 |             # lora_model delegates the call to the wrapped model
68 |             response = Generate.generate(lora_model, query=build_template(input), tokenizer=tokenizer, max_length=512,
69 |                                          eos_token_id=config.eos_token_id,
70 |                                          do_sample=False, top_p=0.7, temperature=0.95, )
71 |             print('input', input)
72 |             print('output', response)
73 |
74 |     lora_model.set_adapter(adapter_name='default')
75 |
76 |     for input in text_list:
77 |         # lora_model delegates the call to the wrapped model
78 |         response = Generate.generate(lora_model, query=build_template(input), tokenizer=tokenizer, max_length=512,
79 |                                      eos_token_id=config.eos_token_id,
80 |                                      do_sample=False, top_p=0.7, temperature=0.95, )
81 |         print('input', input)
82 |         print('output', response)
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/src/training/infer_ptuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer_ptuning
4 | import os
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser,AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PromptArguments
10 | from aigc_zoo.utils.llm_generate import Generate
11 |
12 | if __name__ == '__main__':
13 |     train_info_args['seed'] = None
14 |     parser = HfArgumentParser((ModelArguments,))
15 |     (model_args,) = parser.parse_dict(train_info_args,allow_extra_keys=True)
16 |
17 |     dataHelper = NN_DataHelper(model_args)
18 |     tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs={"torch_dtype": torch.float16})
19 |
20 |
21 |     train_weight_dir = './best_ckpt/last'
22 |     config = AutoConfig.from_pretrained(train_weight_dir)
23 |     prompt_args = PromptArguments.from_pretrained(train_weight_dir)
24 |
25 |     assert prompt_args.inference_mode == True
26 |
27 |     new_num_tokens = config.vocab_size
28 |     if config.task_specific_params is not None and config.task_specific_params.get('vocab_size', None) is not None:
29 |         config.vocab_size = config.task_specific_params['vocab_size']
30 |
31 |     pl_model = MyTransformer(config=config, model_args=model_args,
32 |                              prompt_args=prompt_args,
33 |                              new_num_tokens=new_num_tokens,
34 |                              )
35 |     # load the sft weights
36 |     pl_model.load_sft_weight(train_weight_dir)
37 |
38 |     pl_model.eval().half().cuda()
39 |
40 |     model = pl_model.get_llm_model()
41 |
42 |     # base model precision
43 |     model.base_model_torch_dtype = torch.half
44 |
45 |     text_list = ["写一个诗歌,关于冬天",
46 |                  "晚上睡不着应该怎么办",
47 |                  "从南京到上海的路线"]
48 |     for input in text_list:
49 |
50 |         response = Generate.generate(model, query=build_template(input), tokenizer=tokenizer, max_length=512,
51 |                                      eos_token_id=config.eos_token_id,
52 |                                      do_sample=False, top_p=0.7, temperature=0.95, )
53 |         print('input', input)
54 |         print('output', response)
--------------------------------------------------------------------------------
/src/training/module_setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : ssbuild
3 | # @Time : 2023/8/16 16:03
4 |
5 | from deep_training.utils.hf import register_transformer_model,register_transformer_config
6 | from transformers import AutoModelForCausalLM
7 | from deep_training.nlp.models.rellama.modeling_llama import LlamaForCausalLM
8 | __all__ = [
9 |     "module_setup"
10 | ]
11 |
12 | def module_setup():
13 |     # register the model
14 |     #register_transformer_config(XverseConfig)
15 |     register_transformer_model(LlamaForCausalLM, AutoModelForCausalLM)
--------------------------------------------------------------------------------
/src/training/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 | import os.path
4 | import torch
5 | from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments
6 | from deep_training.trainer.pl.modelcheckpoint import ModelCheckpointEx
7 | from lightning import Trainer
8 | from lightning.pytorch.callbacks import LearningRateMonitor
9 | from lightning.pytorch.strategies import DeepSpeedStrategy
10 | from transformers import HfArgumentParser
11 | from data_utils import NN_DataHelper, train_info_args, get_deepspeed_config, global_args
12 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments, LoraConfig, PromptArguments
13 |
14 |
15 | assert global_args["trainer_backend"] == "pl"
16 |
17 | if __name__ == '__main__':
18 |     parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, PetlArguments,PromptArguments))
19 |     model_args, training_args, data_args, lora_args,prompt_args = parser.parse_dict(train_info_args)
20 |     lora_args = lora_args.config
21 |     prompt_args = prompt_args.config
22 |
23 |     output_weight_dir = data_args.output_dir + '/best_ckpt'
24 |
25 |
26 |     dataHelper = NN_DataHelper(model_args, training_args, data_args)
27 |     config_kwargs = {"torch_dtype": torch.float16}
28 |     if global_args['config_merge']:
29 |         config_kwargs.update(global_args['config_merge'])
30 |
31 |     tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs)
32 |
33 |     dataHelper.make_dataset_all()
34 |
35 |     is_bf16_supported = torch.cuda.is_bf16_supported()
36 |     # precision: adjust according to your setup
37 |     if is_bf16_supported:
38 |         precision = 'bf16'
39 |     else:
40 |         precision = '16'
41 |
42 |     if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
43 |         precision = "32"
44 |
45 |     deepspeed_config = get_deepspeed_config(precision)
46 |     strategy = 'ddp' if torch.cuda.device_count() > 1 else 'auto'
47 |     if deepspeed_config is not None and len(deepspeed_config):
48 |         warmup_ratio = 0.1
49 |         with open(train_info_args['train_file'][0]) as f:
50 |             total_steps = len(f.readlines()) * train_info_args['max_epochs']
51 |         total_steps /= len(train_info_args['devices']) * train_info_args['train_batch_size']
52 |         deepspeed_config['scheduler']['params']['warmup_num_steps'] = int(total_steps*warmup_ratio)
53 |         deepspeed_config['scheduler']['params']['total_num_steps'] = int(total_steps)
54 |         print("total steps: ", int(total_steps))
55 |         print("steps per epoch: ", int(total_steps/train_info_args['max_epochs']))
56 |         # from IPython import embed
57 |         # embed()
58 |         # exit()
59 |         strategy = DeepSpeedStrategy(config=deepspeed_config, )
60 |     checkpoint_callback = ModelCheckpointEx(
61 |         # monitor='loss',
62 |         dirpath=output_weight_dir,
63 |         save_weights_only=True,
64 |         save_last=False,
65 |         save_top_k=-1,
66 |         # every_n_train_steps=2000 // training_args.gradient_accumulation_steps,
67 |         every_n_epochs=1,
68 |         lora_args=lora_args,
69 |         prompt_args=prompt_args,
70 |     )
71 |
72 |
73 |     trainer = Trainer(
74 |         callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
75 |         max_epochs=training_args.max_epochs,
76 |         max_steps=training_args.max_steps,
77 |         # max_steps=1,
78 |         accelerator="gpu",
79 |         devices=data_args.devices,
80 |         enable_progress_bar=True,
81 |         default_root_dir=data_args.output_dir,
82 |         gradient_clip_val=training_args.max_grad_norm,
83 |         accumulate_grad_batches=training_args.gradient_accumulation_steps,
84 |         num_sanity_val_steps=0,
85 |         strategy=strategy,
86 |         log_every_n_steps=1,
87 |         # lora int8 precision='32'
88 |         precision=precision,# you can also try "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"
89 |     )
90 |
91 |
92 |     transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args, prompt_args=prompt_args,
93 |                             quantization_config=global_args["quantization_config"],
94 |                             device_map={"": trainer.local_rank} if trainer.world_size > 1 else "auto",
95 |                             torch_dtype=torch.float16,
96 |                             new_num_tokens=len(tokenizer), # the vocabulary may have been extended
97 |                             )
98 |
99 |     if transformer_args["quantization_config"] is None:
100 |         transformer_args.pop("device_map")
101 |
102 |     pl_model = MyTransformer(**transformer_args)
103 |
104 |     config.save_pretrained(output_weight_dir)
105 |
106 |     # load sft weights
107 |     # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True)
108 |
109 |     # pl_model = pl_model.float() if not is_bf16_supported else pl_model.bfloat16()
110 |
111 |     def dataset_loader_filter_fn(dataset):
112 |         print('*' * 30, 'total', len(dataset))
113 |         return dataset
114 |
115 |
116 |     train_datasets = dataHelper.load_distributed_random_sampler(
117 |         dataHelper.train_files,
118 |         with_load_memory=data_args.data_backend == 'record',
119 |         collate_fn=dataHelper.collate_fn,
120 |         batch_size=training_args.train_batch_size,
121 |         drop_last=training_args.dataloader_drop_last, # recommended when training on multiple GPUs
122 |         num_processes=trainer.world_size, process_index=trainer.global_rank,
123 |         dataset_loader_filter_fn=dataset_loader_filter_fn,
124 |         num_workers=training_args.dataloader_num_workers,
125 |         pin_memory=training_args.dataloader_pin_memory,
126 |     )
127 |
128 |     if train_datasets is not None:
129 |         trainer.fit(pl_model, train_dataloaders=train_datasets)
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/src/training/train_hf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : ssbuild
3 | # @Time : 2023/9/25 12:29
4 |
5 |
6 | import logging
7 | import math
8 | import os
9 | import sys
10 | import datasets
11 | import torch
12 | import transformers
13 | from deep_training.trainer.hf.trainer import TrainerHF
14 | from transformers import (
15 |     HfArgumentParser,
16 |     default_data_collator,
17 |     set_seed,
18 | )
19 | from transformers.trainer_utils import get_last_checkpoint
20 | from transformers.utils import check_min_version, send_example_telemetry
21 | from transformers.utils.versions import require_version
22 | from data_utils import NN_DataHelper, train_info_args, get_deepspeed_config, global_args
23 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments, LoraConfig, PromptArguments
24 | from deep_training.data_helper import ModelArguments, DataArguments,TrainingArgumentsHF
25 |
26 | assert global_args["trainer_backend"] == "hf"
27 |
28 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
29 | check_min_version("4.33.2")
30 |
31 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 | # Setup logging
36 | logging.basicConfig(
37 |     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
38 |     datefmt="%m/%d/%Y %H:%M:%S",
39 |     handlers=[logging.StreamHandler(sys.stdout)],
40 | )
41 |
42 | def main():
43 |     training_args: TrainingArgumentsHF
44 |     parser = HfArgumentParser((ModelArguments, TrainingArgumentsHF, DataArguments, PetlArguments, PromptArguments),
45 |                               conflict_handler='resolve')
46 |     model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(train_info_args,allow_extra_keys=True,)
47 |     lora_args = lora_args.config
48 |     prompt_args = prompt_args.config
49 |
50 |     if training_args.should_log:
51 |         # The default of training_args.log_level is passive, so we set log level at info here to have that default.
52 |         transformers.utils.logging.set_verbosity_info()
53 |
54 |     log_level = training_args.get_process_log_level()
55 |     logger.setLevel(log_level)
56 |     datasets.utils.logging.set_verbosity(log_level)
57 |     transformers.utils.logging.set_verbosity(log_level)
58 |     transformers.utils.logging.enable_default_handler()
59 |     transformers.utils.logging.enable_explicit_format()
60 |
61 |     dataHelper = NN_DataHelper(model_args, training_args, data_args)
62 |     config_kwargs = {"torch_dtype": torch.float16}
63 |     if global_args['config_merge']:
64 |         config_kwargs.update(global_args['config_merge'])
65 |
66 |     tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs)
67 |
68 |     with training_args.main_process_first(desc="make_dataset_all"):
69 |         dataHelper.make_dataset_all()
70 |
71 |     is_bf16_supported = torch.cuda.is_bf16_supported()
72 |     # precision: adjust according to your setup
73 |     if is_bf16_supported:
74 |         precision = 'bf16'
75 |     else:
76 |         precision = '16'
77 |
78 |     if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
79 |         precision = "32"
80 |
81 |
82 |     if str(precision) == '16':
83 |         training_args.fp16 = True
84 |     elif str(precision) == 'bf16':
85 |         training_args.bf16 = True
86 |     else:
87 |         training_args.fp16 = False
88 |         training_args.bf16 = False
89 |
90 |     deepspeed_config = get_deepspeed_config(precision)
91 |     if deepspeed_config:
92 |         training_args.deepspeed = deepspeed_config
93 |
94 |     # Log on each process the small summary:
95 |     logger.warning(
96 |         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
97 |         + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
98 |     )
99 |     logger.info(f"Training/evaluation parameters {training_args}")
100 |
101 |     # Detecting last checkpoint.
102 |     last_checkpoint = None
103 |     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
104 |         last_checkpoint = get_last_checkpoint(training_args.output_dir)
105 |         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
106 |             raise ValueError(
107 |                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
108 |                 "Use --overwrite_output_dir to overcome."
109 |             )
110 |         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
111 |             logger.info(
112 |                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
113 |                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
114 |             )
115 |
116 |     # Set seed before initializing model.
117 |     set_seed(training_args.seed)
118 |
119 |     world_size,local_rank,process_index = training_args.world_size,training_args.local_rank,training_args.process_index
120 |
121 |     transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args,
122 |                             prompt_args=prompt_args,
123 |                             quantization_config=global_args["quantization_config"],
124 |                             device_map={"": local_rank} if world_size > 1 else "auto",
125 |                             torch_dtype=torch.float16,
126 |                             new_num_tokens=len(tokenizer), # the vocabulary may have been extended
127 |                             )
128 |
129 |     if transformer_args["quantization_config"] is None:
130 |         transformer_args.pop("device_map")
131 |
132 |     pl_model = MyTransformer(**transformer_args)
133 |
134 |     config.save_pretrained(training_args.output_dir)
135 |
136 |     # load sft weights
137 |     # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True)
138 |
139 |     pl_model = pl_model.float() if not is_bf16_supported else pl_model.bfloat16()
140 |
141 |     train_datasets = None
142 |     if training_args.do_train:
143 |         train_datasets = dataHelper.load_distributed_random_sampler(
144 |             dataHelper.train_files,
145 |             with_load_memory=data_args.data_backend == 'record',
146 |             collate_fn=dataHelper.collate_fn,
147 |             batch_size=training_args.train_batch_size,
148 |             drop_last=training_args.dataloader_drop_last, # recommended when training on multiple GPUs
149 |             num_processes=world_size, process_index=process_index,
150 |             num_workers = training_args.dataloader_num_workers,
151 |             pin_memory = training_args.dataloader_pin_memory,
152 |         )
153 |
154 |
155 |
156 |     # Initialize our Trainer
157 |     trainer = TrainerHF(
158 |         model=pl_model,
159 |         args=training_args,
160 |         train_dataset=train_datasets,
161 |         tokenizer=tokenizer,
162 |         # Data collator will default to DataCollatorWithPadding, so we change it.
163 |         data_collator=default_data_collator,
164 |     )
165 |
166 |     # Training
167 |     if training_args.do_train:
168 |         checkpoint = None
169 |         if training_args.resume_from_checkpoint is not None:
170 |             checkpoint = training_args.resume_from_checkpoint
171 |         elif last_checkpoint is not None:
172 |             checkpoint = last_checkpoint
173 |         train_result = trainer.train(resume_from_checkpoint=checkpoint)
174 |         trainer.save_model() # Saves the tokenizer too for easy upload
175 |
176 |         metrics = train_result.metrics
177 |         metrics["train_samples"] = len(train_datasets)
178 |         trainer.log_metrics("train", metrics)
179 |         trainer.save_metrics("train", metrics)
180 |         trainer.save_state()
181 |
182 |
183 |
184 |
185 | def _mp_fn(index):
186 |     # For xla_spawn (TPUs)
187 |     main()
188 |
189 |
190 | if __name__ == "__main__":
191 |     main()
192 |
--------------------------------------------------------------------------------