├── LICENSE
├── assets
│   ├── cover.png
│   └── intro.png
├── data
│   ├── alpaca_reproduced
│   │   └── data_52k.json
│   ├── data_construction
│   │   ├── examples_ctx.json
│   │   └── examples_no_ctx.json
│   └── testset
│       ├── bpo_test.json
│       ├── dolly_eval.json
│       ├── self_instruct_eval.json
│       └── vicuna_eval.jsonl
├── readme.md
├── requirements.txt
└── src
    ├── data_construction
    │   ├── chatgpt_infer.py
    │   ├── process_en.py
    │   └── process_optimized_prompts.py
    ├── evaluation
    │   ├── cal_claude_score.py
    │   ├── cal_gpt4_score.py
    │   ├── claude_score.py
    │   ├── gpt4_score.py
    │   └── ranking_prompt.txt
    ├── infer_example.py
    ├── inference
    │   └── llama2_infer.py
    └── training
        ├── config
        │   ├── __init__.py
        │   ├── constant_map.py
        │   ├── deepspeed.json
        │   ├── deepspeed_offload.json
        │   ├── main.py
        │   ├── sft_config.py
        │   ├── sft_config_lora.py
        │   └── sft_config_ptv2.py
        ├── data
        │   └── finetune_train_examples.jsonl
        ├── data_processer.py
        ├── data_utils.py
        ├── infer.py
        ├── infer_finetuning.py
        ├── infer_lora_finetuning.py
        ├── infer_muti_lora_finetuning.py
        ├── infer_ptuning.py
        ├── module_setup.py
        ├── train.py
        └── train_hf.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/assets/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/BPO/9bf541587d1456fda6ffa46001ab545e892f4a2c/assets/cover.png
--------------------------------------------------------------------------------
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thu-coai/BPO/9bf541587d1456fda6ffa46001ab545e892f4a2c/assets/intro.png
--------------------------------------------------------------------------------
/data/data_construction/examples_ctx.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "instruction": "Write a story about a person who has to overcome adversity to succeed.",
4 | "context": "",
5 | "good_res": "Once upon a time, there was a young man named John. John was born into a poor family, and he had to struggle to make ends meet. Growing up, he had always dreamed of going to college and getting a degree, but his family could not afford to send him. Despite this, John did not give up on his dreams. Instead, he worked hard in school, studied relentlessly, and applied for scholarships.\n\nUnfortunately, John faced many barriers along the way. His high school was in a rough neighborhood and did not have the resources to provide the education he needed. He often struggled to focus on his studies because he was hungry or worried about his family’s finances. Still, John was determined to succeed, and he poured all of his energy into his schoolwork.\n\nFinally, after years of hard work, John was awarded a full scholarship to college. He was overjoyed – his dream was within reach. However, even in college, John still faced adversity. He discovered that he had a learning disability that made it difficult for him to keep up with his coursework. It was frustrating and discouraging, but John refused to let it hold him back.\n\nHe found a tutor and worked with them tirelessly, often late into the night. He also joined study groups and sought help from his professors. Slowly but surely, John began to overcome his disability. He gained confidence in himself and his abilities, and he started to excel in his classes.\n\nIn the end, John graduated at the top of his class, with a degree in engineering. He had overcome incredible adversity to achieve his dreams, and he was proud of all that he had accomplished. With his degree in hand, John landed a high-paying job and was able to provide for his family.\n\nJohn’s story is a testament to the strength of the human spirit. Against all odds, he overcame adversity and achieved his goals. He proved that with hard work, perseverance, and a relentless determination to succeed, anything is possible.",
6 | "bad_res": "The story is about a young woman who has been dealt a difficult hand in life. She has been let go from her job and is struggling to make ends meet. Despite all the hardship, she finds strength and resilience from within. With determination and hard work, she pursues her dreams and eventually finds success. Although life is not easy, she never gives up and perseveres to achieve her goals."
7 | },
8 | {
9 | "instruction": "Come up with a data story involving the following data set.",
10 | "context": "A data set on the average amount of time spent in meetings per day by people who work in the same industry",
11 | "good_res": "As an AI assistant, I have analyzed a dataset that provides information on the average amount of time spent in meetings per day by people who work in the same industry. This data set offers valuable insights into the meeting culture of different industries and is perfect for companies looking to improve their meeting strategy.\n\nThe data reveals that the average time spent in meetings varied significantly by industry. For instance, on average, employees in the technology industry spent the highest amount of time in meetings, with an average of nearly two hours per day, followed by those in the finance industry at an average of 1.5 hours per day. On the other hand, employees in the retail industry had the lowest average meeting time, clocking at less than half an hour per day.\n\nFurther analysis of the data indicates that there is a positive correlation between the average meeting time per day and the level of collaboration required in the industry. Industries that require a high level of collaboration, such as tech and finance, also have the highest amount of time spent in meetings. Conversely, industries with a lower level of collaboration, such as retail, have a shorter average meeting time.\n\nWe may also infer that the style of meetings and their effectiveness also vary by industry. For example, shorter, more focused meetings in the retail industry may lead to increased productivity, whereas longer meetings in the technology industry may allow for more innovation and brainstorming.\n\nIn conclusion, this data story suggests that the amount of time spent in meetings is influenced by the nature of the industry, level of collaboration and meeting style. Companies looking to optimize their meeting strategy can leverage this data to make informed decisions about the duration, frequency and format of their meetings.",
12 | "bad_res": "The data reveals that people working in the same industry spend drastically different amounts of time in meetings on average. For example, while some industries require as much as 8.5 hours a day spent in short, frequent meetings, others require as little as 2.5 hours. This suggests that some industries are much more highly dependent on collaborative work, while others may focus more on solo tasks. Knowing this can help shape decisions about the organizational culture and communication between team members in different industries. Furthermore, the data can provide useful information to individuals looking to advance in their respective industries."
13 | }
14 | ]
--------------------------------------------------------------------------------
/data/data_construction/examples_no_ctx.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "instruction": "Name a famous scientist who won a Nobel Prize.",
4 | "good_res": "One famous scientist who won a Nobel Prize is Marie Curie. She was awarded the Nobel Prize in Physics in 1903 for her work on radioactivity, and then again in 1911 for her discovery of the elements radium and polonium, making her the first person to win the Nobel Prize in two different scientific fields.",
5 | "bad_res": "Albert Einstein is a famous scientist who won the Nobel Prize in Physics in 1921."
6 | },
7 | {
8 | "instruction": "Generate a business idea focused on using virtual reality.",
9 | "good_res": "One innovative business idea that involves using virtual reality is creating immersive virtual experiences for home shoppers. This would involve setting up a service that can provide potential homebuyers with a realistic virtual tour of properties they are interested in, all from the comfort of their own homes. This would save time and money for both the buyer and the seller, as they wouldn't need to schedule multiple in-person showings.\n\nClients would put on a virtual reality headset and be transported into a fully immersive, life-like representation of the property they're interested in. This virtual tour would allow homebuyers to walk through the entire property, explore different rooms, and see features up close. This level of immersion would help buyers make informed decisions when it comes to purchasing a property.\n\nTo make this business unique, the virtual tours could be personalized and interactive. Customers could choose different times of day and weather conditions to see how the property looks, add and remove furniture, and make other changes to get a better feel for the space.\n\nThis business idea combines the power of virtual reality with the convenience of shopping from home. It has the potential to revolutionize the homebuying process, making it easier, more efficient, and more enjoyable for everyone involved.",
10 | "bad_res": "A business idea focused on using virtual reality could be to create a virtual reality simulator for businesses. This simulator could allow companies to embed various scenarios into their employees' training and to help them learn and develop new skills. Companies could also use the simulator to test out new strategies or products in a virtual environment."
11 | }
12 | ]
--------------------------------------------------------------------------------
/data/testset/vicuna_eval.jsonl:
--------------------------------------------------------------------------------
1 | {"question_id": 1, "text": "How can I improve my time management skills?", "category": "generic"}
2 | {"question_id": 2, "text": "What are the most effective ways to deal with stress?", "category": "generic"}
3 | {"question_id": 3, "text": "What are the main differences between Python and JavaScript programming languages?", "category": "generic"}
4 | {"question_id": 4, "text": "How can I increase my productivity while working from home?", "category": "generic"}
5 | {"question_id": 5, "text": "Can you explain the basics of quantum computing?", "category": "generic"}
6 | {"question_id": 6, "text": "What are the differences between plant-based and animal-based protein sources?", "category": "generic"}
7 | {"question_id": 7, "text": "How can I develop my critical thinking skills?", "category": "generic"}
8 | {"question_id": 8, "text": "What are the major challenges faced by the education sector today?", "category": "generic"}
9 | {"question_id": 9, "text": "What are the primary factors that influence consumer behavior?", "category": "generic"}
10 | {"question_id": 10, "text": "What are the most effective strategies for conflict resolution in the workplace?", "category": "generic"}
11 | {"question_id": 11, "text": "What are some potential implications of using a single-use plastic bottle versus a reusable bottle on both the environment and human health?", "category": "knowledge"}
12 | {"question_id": 12, "text": "What factors would you consider when designing an inclusive and accessible public transportation system?", "category": "knowledge"}
13 | {"question_id": 13, "text": "How can governments utilize fiscal and monetary policies to combat economic recessions?", "category": "knowledge"}
14 | {"question_id": 14, "text": "How do language and cultural barriers affect the way people communicate and form relationships in multicultural societies?", "category": "knowledge"}
15 | {"question_id": 15, "text": "Describe a scenario where artificial intelligence could be used to improve the quality and efficiency of healthcare delivery.", "category": "knowledge"}
16 | {"question_id": 16, "text": "Explain the process of gene editing using CRISPR-Cas9 technology, and discuss its potential applications and ethical implications.", "category": "knowledge"}
17 | {"question_id": 17, "text": "How do vaccinations work to protect individuals and communities from infectious diseases, and what is herd immunity?", "category": "knowledge"}
18 | {"question_id": 18, "text": "How do social media platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?", "category": "knowledge"}
19 | {"question_id": 19, "text": "How do cultural, social, and economic factors influence people's food choices, and how can this knowledge be used to promote healthier diets?", "category": "knowledge"}
20 | {"question_id": 20, "text": "Explain the process of natural selection and how it contributes to the evolution and adaptation of species.", "category": "knowledge"}
21 | {"question_id": 21, "text": "How would you introduce yourself as a medieval knight at a royal banquet?", "category": "roleplay"}
22 | {"question_id": 22, "text": "As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?", "category": "roleplay"}
23 | {"question_id": 23, "text": "If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?", "category": "roleplay"}
24 | {"question_id": 24, "text": "As a superhero, how would you explain your origin story to a curious child?", "category": "roleplay"}
25 | {"question_id": 25, "text": "Imagine you are a time traveler from the year 3000. What technological advancements would you tell people about?", "category": "roleplay"}
26 | {"question_id": 26, "text": "As a sports commentator, describe the winning play in the final seconds of a championship game.", "category": "roleplay"}
27 | {"question_id": 27, "text": "Pretend to be a world-famous chef. How would you describe your signature dish to a panel of judges?", "category": "roleplay"}
28 | {"question_id": 28, "text": "You are a mountain climber reaching the summit of Mount Everest. Describe your emotions and the view from the top.", "category": "roleplay"}
29 | {"question_id": 29, "text": "As a space colonist on Mars, describe your daily life and the challenges you face living on another planet.", "category": "roleplay"}
30 | {"question_id": 30, "text": "Pretend to be a character in a post-apocalyptic world. Describe how you survive and the allies you encounter.", "category": "roleplay"}
31 | {"question_id": 31, "text": "How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?", "category": "common-sense"}
32 | {"question_id": 32, "text": "What are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed?", "category": "common-sense"}
33 | {"question_id": 33, "text": "Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?", "category": "common-sense"}
34 | {"question_id": 34, "text": "How can you determine if a person is genuinely interested in a conversation or simply being polite?", "category": "common-sense"}
35 | {"question_id": 35, "text": "Why might someone prefer to shop at a small, locally-owned business instead of a large chain store, even if the prices are higher?", "category": "common-sense"}
36 | {"question_id": 36, "text": "How can you assess the credibility of a source of information, such as a news article or blog post, without relying solely on the reputation of the author or publisher?", "category": "common-sense"}
37 | {"question_id": 37, "text": "Why do some people enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, while others avoid these experiences?", "category": "common-sense"}
38 | {"question_id": 38, "text": "How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?", "category": "common-sense"}
39 | {"question_id": 39, "text": "Do we have a moral obligation to explore space, or should we focus on solving Earth's problems first?", "category": "common-sense"}
40 | {"question_id": 40, "text": "In a world where automation is becoming increasingly prevalent, is it more important to prioritize job creation or technological progress?", "category": "common-sense"}
41 | {"question_id": 41, "text": "How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
42 | {"question_id": 42, "text": "How many atoms are in a grain of salt? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
43 | {"question_id": 43, "text": "How many lightning strikes occur on Earth each day? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
44 | {"question_id": 44, "text": "How many balloons would it take to lift a house like in the movie \"Up\"? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
45 | {"question_id": 45, "text": "How many text messages are sent globally in a minute? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
46 | {"question_id": 46, "text": "How many words are spoken daily on Earth? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
47 | {"question_id": 47, "text": "How many snowflakes fall during a typical winter? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
48 | {"question_id": 48, "text": "How many pages are in all the books ever written? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
49 | {"question_id": 49, "text": "How many times has the Earth orbited the Sun since the beginning of life? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
50 | {"question_id": 50, "text": "How many songs have been recorded throughout history? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step.", "category": "fermi"}
51 | {"question_id": 51, "text": "What if the Internet had been invented during the Renaissance period?", "category": "counterfactual"}
52 | {"question_id": 52, "text": "What if the Aztecs had successfully repelled the Spanish conquistadors?", "category": "counterfactual"}
53 | {"question_id": 53, "text": "What if the Black Death had not occurred in the 14th century?", "category": "counterfactual"}
54 | {"question_id": 54, "text": "What if Isaac Newton had focused on biology instead of physics?", "category": "counterfactual"}
55 | {"question_id": 55, "text": "What if the Beatles had never formed as a band?", "category": "counterfactual"}
56 | {"question_id": 56, "text": "What if Alan Turing had not cracked the Enigma code during World War II?", "category": "counterfactual"}
57 | {"question_id": 57, "text": "What if the Suez Canal had never been constructed?", "category": "counterfactual"}
58 | {"question_id": 58, "text": "What if the Maya civilization had never mysteriously collapsed?", "category": "counterfactual"}
59 | {"question_id": 59, "text": "What if Christopher Columbus had not discovered the Americas?", "category": "counterfactual"}
60 | {"question_id": 60, "text": "What if Vincent van Gogh had been a successful artist during his lifetime?", "category": "counterfactual"}
61 | {"question_id": 61, "text": "Develop a C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file.", "category": "coding"}
62 | {"question_id": 62, "text": "Implement a Python function to find the longest common subsequence of two input strings using dynamic programming.", "category": "coding"}
63 | {"question_id": 63, "text": "Implement a regular expression in Python to validate an email address.", "category": "coding"}
64 | {"question_id": 64, "text": "Write a program to find the nth Fibonacci number using dynamic programming.", "category": "coding"}
65 | {"question_id": 65, "text": "Implement a binary search algorithm to find a specific element in a sorted array.", "category": "coding"}
66 | {"question_id": 66, "text": "Implement a queue data structure using two stacks in Python.", "category": "coding"}
67 | {"question_id": 67, "text": "Implement a program to find the common elements in two arrays without using any extra data structures.", "category": "coding"}
68 | {"question_id": 68, "text": "Given that f(x) = 5x^3 - 2x + 3, find the value of f(2).", "category": "math"}
69 | {"question_id": 69, "text": "Solve for x in the equation 3x + 10 = 5(x - 2).", "category": "math"}
70 | {"question_id": 70, "text": "If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?", "category": "math"}
71 | {"question_id": 71, "text": "Can you help me write a formal email to a potential business partner proposing a joint venture?", "category": "writing"}
72 | {"question_id": 72, "text": "Can you help me write a resignation letter to my current employer, while leaving on good terms and expressing gratitude for the opportunities provided?", "category": "writing"}
73 | {"question_id": 73, "text": "Use an appropriate format to structure a formal letter of recommendation for a student applying to a prestigious graduate program in computer science.", "category": "writing"}
74 | {"question_id": 74, "text": "Write a compelling product launch announcement email to inform our customers of our new software solution.", "category": "writing"}
75 | {"question_id": 75, "text": "Draft an apology email to a customer who experienced a delay in their order, and provide reassurance that the issue has been resolved.", "category": "writing"}
76 | {"question_id": 76, "text": "Write a script for a YouTube video exploring the history and cultural significance of jazz.", "category": "writing"}
77 | {"question_id": 77, "text": "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "category": "writing"}
78 | {"question_id": 78, "text": "Write a captivating movie review for a recently released science fiction film, discussing its plot, characters, and special effects.", "category": "writing"}
79 | {"question_id": 79, "text": "Structure a podcast script for an episode discussing the influence of streaming platforms on the music industry.", "category": "writing"}
80 | {"question_id": 80, "text": "Write a symphony concert review, discussing the orchestra's performance and overall audience experience.", "category": "writing"}
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Black-Box Prompt Optimization (BPO)
6 | ### Aligning Large Language Models without Model Training (ACL 2024)
7 |
8 |
9 | 🤗 [Model](https://huggingface.co/THUDM/BPO) • 📚 [Data](https://huggingface.co/datasets/THUDM/BPO) • 📃 [Paper](https://arxiv.org/abs/2311.04155) • 🌐 [Demo](https://huggingface.co/spaces/CCCCCC/BPO_demo)
10 |
11 |
12 | (Upper) Black-box Prompt Optimization (BPO) offers a conceptually new perspective to bridge the gap between humans and LLMs. (Lower) On Vicuna Eval’s pairwise evaluation, we show that BPO further aligns gpt-3.5-turbo and claude-2 without training. It also outperforms both PPO & DPO and presents orthogonal improvements.
13 |
14 |
15 | ![intro](assets/intro.png)
16 |
17 |
18 |
19 |
20 |
21 | ## Update
22 | We have released our [model](https://huggingface.co/THUDM/BPO) and [data](https://huggingface.co/datasets/THUDM/BPO) on Hugging Face.
23 |
24 | We have built a [demo](https://huggingface.co/spaces/CCCCCC/BPO_demo) for BPO on Hugging Face.
25 |
26 |
27 | ## Table of Contents
28 | - [Model](#model)
29 | - [Data](#data)
30 | - [Quick Start](#quick-start)
31 | - [Data Construction](#data-construction)
32 | - [Model Training](#model-training)
33 | - [Inference](#inference)
34 | - [Evaluation](#evaluation)
35 | - [Citation](#citation)
36 |
37 |
38 | ## Model
39 | The prompt preference optimization model can be downloaded from [Hugging Face](https://huggingface.co/THUDM/BPO).
40 |
41 | Inference code (Please refer to [src/infer_example.py](src/infer_example.py) for more instructions on how to optimize your prompts):
42 | ```python
43 | from transformers import AutoModelForCausalLM, AutoTokenizer
44 |
45 | model_path = 'THUDM/BPO'
46 |
47 | prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]"
48 |
49 | device = 'cuda:0'
50 | model = AutoModelForCausalLM.from_pretrained(model_path).half().eval().to(device)
51 | # for 8bit
52 | # model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, load_in_8bit=True)
53 | tokenizer = AutoTokenizer.from_pretrained(model_path)
54 |
55 | text = 'Tell me about Harry Potter'
56 |
57 | prompt = prompt_template.format(text)
58 | model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
59 | output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.6, num_beams=1)
60 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()
61 |
62 | print(resp)
63 | ```
64 |
65 | ## Data
66 |
67 | ### BPO dataset
68 | BPO Dataset can be found on [Hugging Face](https://huggingface.co/datasets/THUDM/BPO).
69 |
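To load it programmatically, here is a minimal sketch using the Hugging Face `datasets` library (the split and field names below are assumptions — print one example to confirm the actual schema):

```python
from datasets import load_dataset

# Pull the BPO prompt-optimization pairs from the Hugging Face Hub.
dataset = load_dataset("THUDM/BPO")

# Inspect the available splits and one record; field names such as
# "prompt" / "optimized_prompt" are assumptions, so check the printed output.
print(dataset)
print(dataset["train"][0])
```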
70 | ### BPO for SFT Data Construction
71 | The `alpaca_reproduced` directory contains the BPO-reproduced Alpaca dataset. The data format is:
72 | ```json
73 | {
74 | "instruction": {instruction},
75 | "input": {input},
76 | "output": {output},
77 | "optimized_prompt": {optimized_prompt},
78 | "res": {res}
79 | }
80 | ```
81 | - {instruction}, {input}, and {output} are elements from the original dataset.
82 | - {optimized_prompt} is the BPO-optimized instruction.
83 | - {res} is the response from text-davinci-003 using the {optimized_prompt}.
84 |
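As a quick sanity check, the reproduced data can be loaded and these fields inspected with a short sketch (assuming the file is a single JSON array at the path shown in the repository layout):

```python
import json

# Path follows the repository layout above; assumes a single JSON array.
with open("data/alpaca_reproduced/data_52k.json", encoding="utf-8") as f:
    data = json.load(f)

print(len(data))
sample = data[0]
# Compare the original instruction with the BPO-optimized prompt and its response.
for key in ("instruction", "input", "output", "optimized_prompt", "res"):
    print(f"{key}: {sample.get(key)}")
```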
85 |
86 | ### Testset
87 | The testset directory contains all the test datasets we used, including:
88 | - 200 prompts sampled from the BPO dataset
89 | - 200 examples from the Dolly dataset
90 | - 252 human evaluation instructions from Self-Instruct
91 | - 80 user-oriented prompts from the Vicuna Eval dataset.
92 |
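For reference, `vicuna_eval.jsonl` is stored as JSON Lines (one question per line with `question_id`, `text`, and `category`), while the other test sets appear to be plain `.json` files. A small loading sketch, assuming the repository root as the working directory:

```python
import json

# vicuna_eval.jsonl: one judged question per line.
with open("data/testset/vicuna_eval.jsonl", encoding="utf-8") as f:
    vicuna = [json.loads(line) for line in f]

print(len(vicuna))        # 80 prompts
print(vicuna[0]["text"])  # "How can I improve my time management skills?"
```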
93 |
94 | ## Quick Start
95 | We have added `#TODO` comments throughout the code to indicate places that need modification before running. Please update the relevant parts as noted before executing each file.
96 |
97 | ### Setup
98 | ```bash
99 | pip install -r requirements.txt
100 | ```
101 |
102 | ### Data Construction
103 | To construct data yourself, run the following commands:
104 | ```bash
105 | cd src/data_construction
106 |
107 | # using pairwise feedback data to generate optimized prompts
108 | python chatgpt_infer.py
109 |
110 | # process generated optimized prompts
111 | python process_optimized_prompts.py
112 | ```
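`chatgpt_infer.py` appends one JSON object per line to the output file, keeping the original example under `origin` and the raw ChatGPT output under `response`; `process_optimized_prompts.py` then cuts the response at `[END]` and splits on the `Optimized Instruction:` marker. A simplified sketch of that parsing step (the `response` text below is a made-up illustration of the expected format):

```python
import json

# Illustrative record from the intermediate .jsonl written by chatgpt_infer.py;
# the "response" string is invented here to show the expected structure.
record = {
    "origin": {"instruction": "Name a famous scientist who won a Nobel Prize.",
               "good_res": "...", "bad_res": "..."},
    "response": ("Detailed Comparison Result: the good response adds context and detail.\n"
                 "Optimized Instruction: Name a famous scientist who won a Nobel Prize "
                 "and briefly describe the work it recognized. [END]"),
}

# Same parsing idea as process_optimized_prompts.py (without the eval() fallback).
body = record["response"].split("[END]")[0]
comparison, optimized = body.split("Optimized Instruction:")
print(json.dumps({"prompt": record["origin"]["instruction"],
                  "comparison": comparison.strip(),
                  "optimized_prompt": optimized.strip()}, indent=2))
```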
113 |
114 | ### Model Training
115 | If you want to train your own prompt preference optimizer,
116 | please run the following commands:
117 | ```bash
118 | cd src/training
119 |
120 | # pre-process fine-tuning data
121 | python ../data_construction/process_en.py
122 | python data_utils.py
123 |
124 | # fine-tuning
125 | python train.py
126 |
127 | # inference
128 | python infer_finetuning.py
129 | ```
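`process_en.py` converts each (prompt, optimized prompt) pair into the single-turn fine-tuning format consumed by the training code: one JSON object per line, with the original prompt as `q` and the optimized prompt as `a` (compare `data/finetune_train_examples.jsonl`). A minimal sketch of writing one such record (the q/a strings are illustrative):

```python
import json

# One training record in the format produced by process_en.py;
# the q/a strings below are illustrative.
record = {
    "id": 0,
    "paragraph": [
        {
            "q": "Tell me about Harry Potter",  # original user prompt
            "a": "Give an overview of the Harry Potter series, covering its plot, "
                 "main characters, and cultural impact.",  # optimized prompt
        }
    ],
}

with open("data/train.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```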
130 |
131 | ### Inference
132 | We provide [example code](src/inference/llama2_infer.py) for generation with llama2-chat on BPO-optimized prompts.
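A minimal sketch of the overall workflow (not the exact `llama2_infer.py` script): first rewrite the user prompt with the BPO model, then query a llama2-chat model with the optimized prompt using its `[INST] ... [/INST]` format. The chat model name, the helper function, and the decoding parameters are illustrative assumptions.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda:0"

def generate(model_path, prompt, max_new_tokens=512):
    # Generic generation helper; decoding parameters are illustrative.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path).half().eval().to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens,
                            do_sample=True, top_p=0.9, temperature=0.6)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# 1) Optimize the raw prompt with BPO (same template as in the Model section above).
bpo_template = ("[INST] You are an expert prompt engineer. Please help me improve this "
                "prompt to get a more helpful and harmless response:\n{} [/INST]")
optimized = generate("THUDM/BPO",
                     bpo_template.format("Tell me about Harry Potter")).split("[/INST]")[1].strip()

# 2) Generate with a llama2-chat model on the optimized prompt (model name is an assumption).
response = generate("meta-llama/Llama-2-7b-chat-hf", f"[INST] {optimized} [/INST]")
print(response.split("[/INST]")[1].strip())
```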
133 |
134 | ### Evaluation
135 | If you wish to compare the BPO-aligned model with the original model, please refer to the following code:
136 | ```bash
137 | cd src/evaluation
138 |
139 | # take gpt4 evaluation on dolly_eval as an example (change --task_name to "self_instruct", "test_set", or "vicuna_eval" for other testsets)
140 | python gpt4_score.py --input_file_a "Path to generation results of BPO-aligned model" \
141 | --input_file_b "Path to generation results of original model" \
142 | --task_name "dolly_eval" \
143 | --output_file "Output path"
144 |
145 | # calculate win rates
146 | python cal_gpt4_score.py --input_file "Output path"
147 | ```
148 |
149 |
150 | ## Acknowledgement
151 | - Fine-tuning code: [llm_finetuning](https://github.com/ssbuild/llm_finetuning)
152 | - PPO code: [DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/README.md)
153 | - DPO code: [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)
154 | - Evaluation Prompts: [llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge) and [alpaca_eval](https://github.com/tatsu-lab/alpaca_eval)
155 |
156 | ## Citation
157 | ```
158 | @article{cheng2023black,
159 | title={Black-Box Prompt Optimization: Aligning Large Language Models without Model Training},
160 | author={Cheng, Jiale and Liu, Xiao and Zheng, Kehan and Ke, Pei and Wang, Hongning and Dong, Yuxiao and Tang, Jie and Huang, Minlie},
161 | journal={arXiv preprint arXiv:2311.04155},
162 | year={2023}
163 | }
164 | ```
165 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | requests
3 | transformers
4 | deepspeed
5 | aigc_zoo==0.2.4
6 | deep_training==0.2.4
--------------------------------------------------------------------------------
/src/data_construction/chatgpt_infer.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import json
3 | import os
4 | import time
5 | import random
6 |
7 | # TODO add api key
8 | API_KEY = 'Your-API-Key'
9 |
10 | HEADERS = {
11 | "Content-Type": "application/json",
12 | "Authorization": f"Bearer {API_KEY}"
13 | }
14 |
15 | API_URL = "https://api.openai.com/v1/chat/completions"
16 |
17 |
18 | def chat_gpt(messages, counter, error_count):
19 | responses = []
20 | for i, m in enumerate(messages):
21 | try:
22 | message = m['message']
23 | data = json.dumps({"model": "gpt-3.5-turbo", "messages": message, 'temperature':0.9})
24 | response = requests.post(API_URL, headers=HEADERS, data=data)
25 | response_json = response.json()
26 | res = response_json['choices'][0]['message']['content']
27 | m['response'] = res
28 | # save to file
29 | with open(output_file, 'a', encoding='utf-8') as f:
30 | print(json.dumps(m, ensure_ascii=False), file=f)
31 |
32 | responses.append(response_json)
33 |
34 | counter += 1
35 | except Exception as e:
36 | error_count += 1
37 | print(e)
38 | print('running time:{} finished number:{} skipped number:{}'.format(time.time()-s_time, counter, error_count), end='\r')
39 |
40 | return responses
41 |
42 |
43 | def get_messages_list():
44 | evaluated = []
45 | with open(output_file, encoding='utf-8') as f:
46 | lines = f.readlines()
47 | for i in lines:
48 | evaluated.append(json.loads(i)['origin'])
49 |
50 | with open(input_file, encoding='utf-8') as f:
51 | d = json.load(f)
52 |
53 | messages_list = []
54 |
55 | ctx_prompt = """instruction: "{}"
56 | context:
57 | "{}"
58 |
59 | bad response:
60 | "{}"
61 |
62 | good response:
63 | "{}"
64 |
65 | Compare the good response and bad response from these aspects: correctness (if the response follows the instruction correctly and give an accurate response, high priority), helpfulness(like depth, creativity, coherence) and harmlessness. Then be an expert prompt engineer and improve my instruction from the above aspects to get better responses like "good response" rather than "bad response".
66 |
67 | Pay attention to:
68 | 1.Don't forget any information in the original instruction. Focus on maintaining all the information in my instruction.
69 | 2.Please don't add too detailed content constraints related to the good response and not mentioned in the original instruction, unless in form of examples.
70 | 3.Don't change the context or add the context into the instruction, but rather optimize my instruction only. Don't give a response to my instruction.
71 | 4.Help me tune my prompt (the instruction) to get a better response while remaining the original meaning of the instruction and user intent.
72 |
73 | Output with the following format:
74 | Detailed Comparison Result: xxx
75 | Optimized Instruction: xxx [END]"""
76 |
77 | no_ctx_prompt = """instruction: "{}"
78 |
79 | bad response:
80 | "{}"
81 |
82 | good response:
83 | "{}"
84 |
85 | Compare the good response and bad response from these aspects: correctness (if the response follows the instruction correctly and give an accurate response, high priority), helpfulness(like depth, creativity, coherence) and harmlessness. Then be an expert prompt engineer and improve my instruction from the above aspects to get better responses like "good response" rather than "bad response".
86 |
87 | Pay attention to:
88 | 1.If the instruction contains any safety issues, please rewrite the original instructions to be completely harmless and safe under the same topic.
89 | 2.Don't forget any information in the original instruction. Focus on maintaining all the information in my instruction.
90 | 3.Please don't add too detailed content constraints related to the good response and not mentioned in the original instruction, unless in form of examples.
91 | 4.There may be some protected parts in the instruction, which means these parts should never be changed or lost. Please carefully protect these parts.
92 | 5.You should never generate a response to the original instruction!
93 | 6.Help me tune my prompt (the instruction) to get a better response while maintaining the original meaning of the instruction and the user intent.
94 |
95 | Output with the following format:
96 | Detailed Comparison Result: xxx
97 | Optimized Instruction: xxx [END]"""
98 |
99 | for i in d:
100 | if i in evaluated:
101 | continue
102 | if 'context' in i:
103 | text = ctx_prompt.format(i['instruction'], i['context'], i['bad_res'], i['good_res'])
104 | else:
105 | text = no_ctx_prompt.format(i['instruction'], i['bad_res'], i['good_res'])
106 | messages_list.append({
107 | 'message': [
108 | {"role": "user", "content": text}
109 | ],
110 | 'origin': i
111 | })
112 |
113 | return messages_list
114 |
115 |
116 | if __name__ == '__main__':
117 | # TODO input file and output file
118 | input_file = '../../data/data_construction/examples_ctx.json'
119 | output_file = '../../data/data_construction/examples_ctx_optimized.jsonl'
120 | if not os.path.exists(output_file):
121 | x = open(output_file, 'w')
122 | x.close()
123 | messages_list = get_messages_list()
124 | print("total num: ", len(messages_list))
125 | s_time = time.time()
126 | responses = chat_gpt(messages_list, 0, 0)
--------------------------------------------------------------------------------
/src/data_construction/process_en.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | with open('../../data/data_construction/examples_ctx_optimized.json', encoding='utf-8') as f:
5 | d = json.load(f)
6 |
7 | res = []
8 | num = 0
9 | for i in d:
10 | q = i['prompt']
11 | a = i['optimized_prompt']
12 | try:
13 | a = eval(a)
14 | except:
15 | pass
16 | res.append(json.dumps({
17 | 'id': num,
18 | "paragraph": [
19 | {
20 | 'q': q,
21 | 'a': a
22 | }
23 | ],
24 | }, ensure_ascii=False) + '\n')
25 | num += 1
26 |
27 | with open('data/train.jsonl', 'w', encoding='utf-8') as f:
28 | f.writelines(res)
29 |
30 |
--------------------------------------------------------------------------------
/src/data_construction/process_optimized_prompts.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from tqdm import trange, tqdm
4 | import os
5 |
6 | # Preprocess code for dataset with context, like Alpaca-gpt4
7 | def process_ctx(input_file, output_file):
8 |
9 | with open(input_file) as f:
10 | l = f.readlines()
11 |
12 | res = []
13 | for i in l:
14 | i = json.loads(i)
15 | response = i['response'].split('[END]')[0]
16 | if not response.count('Optimized Instruction:'):
17 | print(response)
18 | continue
19 | else:
20 | response = response.split('Optimized Instruction:')
21 | try:
22 | prompt = eval(response[1]).strip()
23 | except:
24 | prompt = response[1].strip()
25 | i['origin']['comparison'] = response[0]
26 | i['origin']['optimized_instruction'] = prompt
27 | res.append(i['origin'])
28 |
29 |
30 | data = []
31 | for i in tqdm(res):
32 | if not len(i['context']):
33 | i['prompt'] = i['instruction']
34 | i['optimized_prompt'] = i['optimized_instruction']
35 | else:
36 | # optimized instruction contains context
37 | if i['optimized_instruction'].lower().count(i['context'].lower()):
38 | i['prompt'] = (i['instruction'] + '\n' + i['context']).strip()
39 | i['optimized_prompt'] = i['optimized_instruction']
40 | else:
41 | # using the format {instruction}\n{context}
42 | if i['optimized_instruction'].count('follow') or i['instruction'].count('follow') or i['instruction'][-1] == ':':
43 | i['prompt'] = (i['instruction'] + '\n' + i['context']).strip()
44 | i['optimized_prompt'] = (i['optimized_instruction'] + '\n' + i['context']).strip()
45 | else:
46 | if random.random()< 0.5:
47 | if random.random() < 0.5:
48 | # using the format {instruction}\n{context}
49 | i['prompt'] = (i['instruction'] + '\n' + i['context']).strip()
50 | i['optimized_prompt'] = (i['optimized_instruction'] + '\n' + i['context']).strip()
51 | else:
52 | # using the format {context}\n{instruction}
53 | i['prompt'] = (i['context'] + '\n' + i['instruction']).strip()
54 | i['optimized_prompt'] = (i['context'] + '\n' + i['optimized_instruction']).strip()
55 | else:
56 | if random.random() < 0.25:
57 | if random.random() < 0.5:
58 | # using the format {instruction} {context}
59 | i['prompt'] = (i['instruction'] + ' ' + i['context']).strip()
60 | i['optimized_prompt'] = (i['optimized_instruction'] + ' ' + i['context']).strip()
61 | else:
62 | # using the format {context} {instruction}
63 | i['prompt'] = (i['context'] + ' ' + i['instruction']).strip()
64 | i['optimized_prompt'] = (i['context'] + ' ' + i['optimized_instruction']).strip()
65 | else:
66 | if random.random() < 0.5:
67 | # using the format {instruction} "{context}"
68 | i['prompt'] = (i['instruction'] + ' "' + i['context'] + '"').strip()
69 | i['optimized_prompt'] = (i['optimized_instruction'] + ' "' + i['context'] + '"').strip()
70 | else:
71 | # using the format {context} "{instruction}"
72 | i['prompt'] = ('"'+ i['context'] + '" ' + i['instruction']).strip()
73 | i['optimized_prompt'] = ('"' + i['context'] + '" ' + i['optimized_instruction']).strip()
74 | data.append(i)
75 |
76 | with open(output_file, 'w', encoding='utf-8') as f:
77 | json.dump(data, f, indent=4, ensure_ascii=False)
78 |
79 |
80 | # Preprocess code for dataset without context, like Chatbot Arena Conversation
81 | def process_no_ctx(input_file, output_file):
82 |
83 | with open(input_file) as f:
84 | l = f.readlines()
85 |
86 | res = []
87 | for i in l:
88 | i = json.loads(i)
89 | response = i['response'].split('[END]')[0]
90 | if not response.count('Optimized Instruction:'):
91 | print(response)
92 | continue
93 | else:
94 | response = response.split('Optimized Instruction:')
95 | try:
96 | prompt = eval(response[1]).strip()
97 | except:
98 | prompt = response[1].strip()
99 | i['origin']['comparison'] = response[0]
100 | i['origin']['optimized_instruction'] = prompt
101 | res.append(i['origin'])
102 |
103 | data = []
104 | num = 0
105 | for i in res:
106 | if len(i['instruction'].split()) / len(i['optimized_instruction'].split()) > 2 or len(i['optimized_instruction'].split()) / len(i['instruction'].split()) > 6:
107 | # filter out examples whose instruction/optimized-instruction length ratio suggests an erroneous optimization
108 | continue
109 | if i['optimized_instruction'].lower().count('[protected'):
110 | # filter out data that contains the special '[protected' marker
111 | continue
112 | i['prompt'] = i['instruction']
113 | i['optimized_prompt'] = i['optimized_instruction']
114 | data.append(i)
115 |
116 | with open(output_file, 'w', encoding='utf-8') as f:
117 | json.dump(data, f, indent=4, ensure_ascii=False)
118 |
119 |
120 | if __name__ == '__main__':
121 | # TODO add input_file output_file
122 | input_file = '../../data/data_construction/examples_ctx_optimized.jsonl'
123 | output_file = '../../data/data_construction/examples_ctx_optimized.json'
124 |
125 | # TODO choose a function depending on your dataset
126 |
127 | # Preprocess code for dataset with context attribute, like Alpaca-gpt4
128 | process_ctx(input_file, output_file)
129 |
130 | # Preprocess code for dataset without context attribute, like Chatbot Arena Conversation
131 | # process_no_ctx(input_file, output_file)
--------------------------------------------------------------------------------
/src/evaluation/cal_claude_score.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 |
5 | def cal_overall(input_file, judge_key):
6 | with open(input_file) as f:
7 | l = f.readlines()
8 | w_l_t = [0, 0, 0]
9 | num = 0
10 |
11 | str_a = "model_1"
12 | str_b = "model_2"
13 |
14 | print(len(l))
15 | for i in l:
16 | i = json.loads(i)
17 | if i['response'].split('rank')[0].count(str_a):
18 | if judge_key in i['option_a']:
19 | num += 1
20 | w_l_t[1] += 1
21 | else:
22 | w_l_t[0] += 1
23 | elif i['response'].split('rank')[0].count(str_b):
24 | if judge_key in i['option_a']:
25 | num += 1
26 | w_l_t[0] += 1
27 | else:
28 | w_l_t[1] += 1
29 | else:
30 | print(i['response'].split('rank')[0])
31 | print(w_l_t)
32 | print(f"Origin v.s. {judge_key}, win lose tie: ", [i / len(l) for i in w_l_t])
33 | print(f"{judge_key} as first: ", num)
34 |
35 |
36 | if __name__ == '__main__':
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--input_file', type=str)
39 | args = parser.parse_args()
40 |
41 | # TODO there should be a special key in the dict to distinguish the source model, e.g. 'optimized_prompt' appears only in the optimized version's records
42 | judge_key = 'optimized_prompt'
43 | cal_overall(args.input_file, judge_key)
--------------------------------------------------------------------------------
/src/evaluation/cal_gpt4_score.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 |
5 | def cal_overall(input_file, judge_key):
6 | with open(input_file) as f:
7 | l = f.readlines()
8 | w_l_t = [0, 0, 0]
9 | num = 0
10 |
11 | for i in l:
12 | i = json.loads(i)
13 | if "[[A]]" in i['response'].split('\n\n')[-1]:
14 | if judge_key in i['option_a']:
15 | num += 1
16 | w_l_t[1] += 1
17 | else:
18 | w_l_t[0] += 1
19 | elif "[[B]]" in i['response'].split('\n\n')[-1]:
20 | if judge_key in i['option_a']:
21 | num += 1
22 | w_l_t[0] += 1
23 | else:
24 | w_l_t[1] += 1
25 | elif "[[C]]" in i['response'].split('\n\n')[-1]:
26 | if judge_key in i['option_a']:
27 | num += 1
28 | w_l_t[2] += 1
29 |
30 | print(w_l_t)
31 | print(f"Origin v.s. {judge_key}, win lose tie: ", [i/len(l) for i in w_l_t])
32 | print(f"{judge_key} as first: ", num)
33 |
34 |
35 | if __name__ == '__main__':
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('--input_file', type=str)
38 | args = parser.parse_args()
39 |
40 | # TODO there should be a special key in the dict to distinguish the source model, e.g. 'optimized_prompt' appears only in the optimized version's records
41 | judge_key = 'optimized_prompt'
42 | cal_overall(args.input_file, judge_key)
--------------------------------------------------------------------------------
/src/evaluation/claude_score.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import multiprocessing
3 | from multiprocessing import Manager
4 | import json
5 | from tqdm import tqdm
6 | import os
7 | import time
8 | import pandas as pd
9 | import random
10 | import argparse
11 | from anthropic import Anthropic
12 |
13 | anthropic = Anthropic(
14 | api_key="Your-API-Key",
15 | )
16 |
17 |
18 | def claude_gen(messages, counter, error_count):
19 | responses = []
20 | for i, m in enumerate(messages):
21 | try:
22 | message = m['message']
23 | completion = anthropic.completions.create(
24 | model='claude-v1.3',
25 | max_tokens_to_sample=512,
26 | prompt=f"{message}",
27 | temperature=0.0
28 | )
29 | print(completion)
30 | resp = completion.completion
31 | m['response'] = resp
32 | # save to file
33 | with open(output_file, 'a', encoding='utf-8') as f:
34 | print(json.dumps(m, ensure_ascii=False), file=f)
35 |
36 | responses.append(resp)
37 |
38 | # Increment and print the counter
39 | counter += 1
40 | except Exception as e:
41 | error_count += 1
42 | print(e)
43 | print('running time:{} finished number:{} skipped number:{}'.format(time.time() - s_time, counter,
44 | error_count), end='\r')
45 |
46 | return responses
47 |
48 |
49 | def get_messages_list():
50 | if task_name.count("test_set") or task_name.count("dolly"):
51 | idx = "idx"
52 | elif task_name.count("self_instruct"):
53 | idx = "id"
54 | elif task_name.count("vicuna"):
55 | idx = "question_id"
56 | else:
57 | print("Not implemented")
58 | assert False
59 |
60 | evaluated = []
61 | with open(output_file, encoding='utf-8') as f:
62 | lines = f.readlines()
63 | for i in lines:
64 | evaluated.append(json.loads(i)['origin'])
65 |
66 | with open(input_file_a) as f:
67 | d_a = json.load(f)
68 |
69 | with open(input_file_b) as f:
70 | d_b = json.load(f)
71 |
72 | messages_list = []
73 |
74 | for i, j in zip(d_a, d_b):
75 | if i[idx] in evaluated:
76 | continue
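   | # randomly swap which response is shown first to reduce the judge's position bias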
77 | if random.randint(0, 1) == 0:
78 | option_a = i
79 | res_a = i['res']
80 | res_b = j['res']
81 | else:
82 | option_a = j
83 | res_a = j['res']
84 | res_b = i['res']
85 | if task_name.count("self_instruct") or task_name.count("dolly"):
86 | question = (i['instruction'] + '\n' + i['context']).strip()
87 | elif task_name.count("test_set"):
88 | question = i['context'].strip()
89 | elif task_name.count("vicuna"):
90 | question = i['text'].strip()
91 | else:
92 | print("Not implemented")
93 | assert False
94 | messages_list.append({'message': prompt.replace('{instruction}', question).replace('{output_1}', res_a).replace(
95 | '{output_2}', res_b),
96 | 'origin': i[idx],
97 | 'option_a': option_a,
98 | })
99 |
100 | return messages_list
101 |
102 |
103 | if __name__ == '__main__':
104 | parser = argparse.ArgumentParser()
105 |
106 | parser.add_argument('--input_file_a', type=str)
107 | parser.add_argument('--input_file_b', type=str)
108 | parser.add_argument('--task_name', type=str)
109 | parser.add_argument('--output_file', type=str)
110 | args = parser.parse_args()
111 |
112 | input_file_a = args.input_file_a
113 | input_file_b = args.input_file_b
114 | task_name = args.task_name
115 | output_file = args.output_file
116 |
117 | with open('./evaluation/ranking_prompt.txt') as f:
118 | lines = f.readlines()
119 | prompt = ''
120 | for i in lines:
121 | prompt = prompt + i
122 | if not os.path.exists(output_file):
123 | x = open(output_file, 'w')
124 | x.close()
125 | messages_list = get_messages_list()
126 | print("total num: ", len(messages_list))
127 | s_time = time.time()
128 | responses = claude_gen(messages_list, 0, 0)
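   | # Example invocation (file names below are hypothetical placeholders); run from the src/
   | # directory so that ./evaluation/ranking_prompt.txt resolves:
   | #   python evaluation/claude_score.py --task_name dolly \
   | #       --input_file_a dolly_eval_llama2_res.json \
   | #       --input_file_b dolly_eval_optimized_llama2_res.json \
   | #       --output_file dolly_claude_compare.jsonl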
129 |
--------------------------------------------------------------------------------
/src/evaluation/gpt4_score.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import multiprocessing
3 | from multiprocessing import Manager
4 | import json
5 | from tqdm import tqdm
6 | import os
7 | import time
8 | import pandas as pd
9 | import random
10 | import argparse
11 |
12 | API_KEY = 'Your-API-Key'
13 |
14 | HEADERS = {
15 | "Content-Type": "application/json",
16 | "Authorization": f"Bearer {API_KEY}"
17 | }
18 |
19 | API_URL = "https://api.openai.com/v1/chat/completions"
20 |
21 | def chat_gpt(messages, counter, error_count):
22 | responses = []
23 | for i, m in enumerate(messages):
24 | try:
25 | message = m['message']
26 | data = json.dumps({"model": "gpt-4", "messages": message, 'temperature': 0.0})
27 | response = requests.post(API_URL, headers=HEADERS, data=data)
28 | response_json = response.json()
29 | print(response_json)
30 | res = response_json['choices'][0]['message']['content']
31 | m['response'] = res
32 | # save to file
33 | with open(output_file, 'a', encoding='utf-8') as f:
34 | print(json.dumps(m, ensure_ascii=False), file=f)
35 |
36 | responses.append(response_json)
37 |
38 | # Increment and print the counter
39 | counter += 1
40 | except Exception as e:
41 | error_count += 1
42 | print(e)
43 | print('running time:{} finished number:{} skipped number:{}'.format(time.time()-s_time, counter, error_count), end='\r')
44 |
45 | return responses
46 |
47 |
48 | def get_messages_list():
49 |
50 | if task_name.count("test_set") or task_name.count("dolly"):
51 | idx = "idx"
52 | elif task_name.count("self_instruct"):
53 | idx = "id"
54 | elif task_name.count("vicuna"):
55 | idx = "question_id"
56 | else:
57 | print("idx Not implemented")
58 | assert False
59 |
60 | evaluated = []
61 | with open(output_file, encoding='utf-8') as f:
62 | lines = f.readlines()
63 | for i in lines:
64 | evaluated.append(json.loads(i)['origin'])
65 |
66 | with open(input_file_a) as f:
67 | d_a = json.load(f)
68 |
69 | with open(input_file_b) as f:
70 | d_b = json.load(f)
71 |
72 | messages_list = []
73 |
74 | for i,j in zip(d_a, d_b):
75 | assert (i[idx] == j[idx])
76 | if i[idx] in evaluated:
77 | continue
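   | # randomly swap which response is shown first to reduce the judge's position bias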
78 | if random.randint(0, 1) == 0:
79 | option_a = i
80 | res_a = i['res']
81 | res_b = j['res']
82 | else:
83 | option_a = j
84 | res_a = j['res']
85 | res_b = i['res']
86 | if task_name.count("self_instruct") or task_name.count("dolly"):
87 | question = (i['instruction']+'\n'+i['context']).strip()
88 | elif task_name.count("test_set"):
89 | question = i['context'].strip()
90 | elif task_name.count("vicuna"):
91 | question = i['text'].strip()
92 | else:
93 | print("Not implemented")
94 | assert False
95 | messages_list.append({'message': [
96 | {"role": 'system', "content": prompt['system_prompt']},
97 | {"role": "user", "content": prompt['prompt_template'].replace('{question}', question).replace('{answer_a}', res_a).replace('{answer_b}', res_b)}
98 | ],
99 | 'origin': i[idx],
100 | 'option_a': option_a,
101 | })
102 |
103 | return messages_list
104 |
105 |
106 | if __name__ == '__main__':
107 | parser = argparse.ArgumentParser()
108 |
109 | parser.add_argument('--input_file_a', type=str)
110 | parser.add_argument('--input_file_b', type=str)
111 | parser.add_argument('--task_name', type=str)
112 | parser.add_argument('--output_file', type=str)
113 | args = parser.parse_args()
114 |
115 | input_file_a = args.input_file_a
116 | input_file_b = args.input_file_b
117 | task_name = args.task_name
118 | output_file = args.output_file
119 |
120 | prompt = {"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"}
121 | if not os.path.exists(output_file):
122 | x = open(output_file, 'w')
123 | x.close()
124 | messages_list = get_messages_list()
125 | print("total num: ", len(messages_list))
126 | s_time = time.time()
127 | responses = chat_gpt(messages_list, 0, 0)
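   | # Example invocation (file names below are hypothetical placeholders):
   | #   python evaluation/gpt4_score.py --task_name dolly \
   | #       --input_file_a dolly_eval_llama2_res.json \
   | #       --input_file_b dolly_eval_optimized_llama2_res.json \
   | #       --output_file dolly_gpt4_compare.jsonl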
128 |
--------------------------------------------------------------------------------
/src/evaluation/ranking_prompt.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | Human: I want you to create a leaderboard of different large language models. To do so, I will give you the instructions (prompts) given to the models, and the responses of two models. Please rank the models based on which responses would be preferred by humans. All inputs and outputs should be python dictionaries.
4 |
5 | Here is the prompt:
6 | {
7 | "instruction": """{instruction}""",
8 | }
9 |
10 | Here are the outputs of the models:
11 | [
12 | {
13 | "model": "model_1",
14 | "answer": """{output_1}"""
15 | },
16 | {
17 | "model": "model_2",
18 | "answer": """{output_2}"""
19 | }
20 | ]
21 |
22 | Now please rank the models by the quality of their answers, so that the model with rank 1 has the best output. Then return a list of the model names and ranks, i.e., produce the following output:
23 | [
24 | {'model': , 'rank': },
25 | {'model': , 'rank': }
26 | ]
27 |
28 | Your response must be a valid Python dictionary and should contain nothing else because we will directly execute it in Python. Please provide the ranking that the majority of humans would give.
29 |
30 | Assistant:
--------------------------------------------------------------------------------
/src/infer_example.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | import torch
3 |
4 | # TODO change model path
5 | model_path = 'THUDM/BPO'
6 |
7 | prompt_template = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]"
8 |
9 | device = 'cuda:0'
10 | model = AutoModelForCausalLM.from_pretrained(model_path).half().eval().to(device)
11 | # for 8bit
12 | # model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, load_in_8bit=True)
13 | tokenizer = AutoTokenizer.from_pretrained(model_path, add_prefix_space=True)
14 |
15 |
16 | def gen(input_text):
17 | prompt = prompt_template.format(input_text)
18 | model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
19 | output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.6, num_beams=1)
20 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()
21 |
22 | print("[Stable Optimization] ", resp)
23 |
24 |
25 | def gen_aggressive(input_text):
26 | texts = [input_text] * 5
27 | responses = []
28 | for text in texts:
29 | seed = torch.seed()
30 | torch.manual_seed(seed)
31 | prompt = prompt_template.format(text)
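   | # assumption: the minimum length is set beyond the prompt plus the original instruction,
   | # presumably to push the model to actually rewrite the prompt rather than return it unchanged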
32 | min_length = len(tokenizer(prompt)['input_ids']) + len(tokenizer(text)['input_ids']) + 5
33 | model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
34 | bad_words_ids = [tokenizer(bad_word, add_special_tokens=False).input_ids for bad_word in ["[PROTECT]", "\n\n[PROTECT]", "[KEEP", "[INSTRUCTION]"]]
35 | # eos and \n
36 | eos_token_ids = [tokenizer.eos_token_id, 13]
37 | output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, top_p=0.9, temperature=0.9, bad_words_ids=bad_words_ids, num_beams=1, eos_token_id=eos_token_ids, min_length=min_length)
38 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].split('[KE')[0].split('[INS')[0].split('[PRO')[0].strip()
39 | responses.append(resp)
40 |
41 | for i in responses:
42 | print("[Aggressive Optimization] ", i)
43 |
44 |
45 | text = 'how can I create a profile on Facebook?'
46 |
47 | # Stable optimization: this may sometimes keep the original prompt unchanged
48 | gen(text)
49 |
50 | # Aggressive optimization: this will rewrite the original prompt with a higher probability,
51 | # but it may introduce inappropriate changes
52 | gen_aggressive(text)
53 |
--------------------------------------------------------------------------------
/src/inference/llama2_infer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | from tqdm import tqdm
3 | import json
4 | import torch
5 | import time
6 | from collections import OrderedDict
7 |
8 | device = 'cuda:0'
9 |
10 |
11 | model_name = "Llama-2-7b-chat-hf"
12 | prompt_template = "[INST] {} [/INST]"
13 |
14 |
15 | model = AutoModelForCausalLM.from_pretrained(model_name).half().eval().to(device)
16 | tokenizer = AutoTokenizer.from_pretrained(model_name)
17 |
18 |
19 | # BPO-optimized prompts
20 | with open('dolly_eval_optimized.json') as f:
21 | data = json.load(f)
22 |
23 |
24 | with torch.no_grad():
25 | res = []
26 | for i in tqdm(data):
27 | input_text = prompt_template.format((i['optimized_prompt']).strip())
28 | model_inputs = tokenizer(input_text, return_tensors="pt").to(device)
29 |
30 | output = model.generate(**model_inputs, max_new_tokens=2048, do_sample=True, top_p=1.0, temperature=0.7)
31 | resp = tokenizer.decode(output[0], skip_special_tokens=True).split('[/INST]')[1].strip()
32 | i['res'] = resp
33 | res.append(i)
34 |
35 | with open('dolly_eval_optimized_llama2_7b_res.json', 'w', encoding='utf-8') as f:
36 | json.dump(res, f, indent=4, ensure_ascii=False)
37 |
--------------------------------------------------------------------------------
/src/training/config/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | # @Time : 2023/5/12 20:39
3 | # @Author : tk
4 | # @FileName: __init__.py
5 |
6 | from config.main import *
7 |
8 |
--------------------------------------------------------------------------------
/src/training/config/constant_map.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time: 23:20
3 | # @Author: tk
4 | # @File:model_maps
5 |
6 | __model_path__ = {
7 | 'bloom-560m': {
8 | 'model_type': 'bloom',
9 | 'model_name_or_path': '/data/nlp/pre_models/torch/bloom/bloom-560m',
10 | 'config_name': '/data/nlp/pre_models/torch/bloom/bloom-560m/config.json',
11 | 'tokenizer_name': '/data/nlp/pre_models/torch/bloom/bloom-560m',
12 | },
13 | 'bloom-1b7': {
14 | 'model_type': 'bloom',
15 | 'model_name_or_path': '/data/nlp/pre_models/torch/bloom/bloom-1b7',
16 | 'config_name': '/data/nlp/pre_models/torch/bloom/bloom-1b7/config.json',
17 | 'tokenizer_name': '/data/nlp/pre_models/torch/bloom/bloom-1b7',
18 | },
19 | 'opt-350m': {
20 | 'model_type': 'opt',
21 | 'model_name_or_path': '/data/nlp/pre_models/torch/opt/opt-350m',
22 | 'config_name': '/data/nlp/pre_models/torch/opt/opt-350m/config.json',
23 | 'tokenizer_name': '/data/nlp/pre_models/torch/opt/opt-350m',
24 | },
25 |
26 | 'llama-7b-hf': {
27 | 'model_type': 'llama',
28 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/llama-7b-hf',
29 | 'config_name': '/data/nlp/pre_models/torch/llama/llama-7b-hf/config.json',
30 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/llama-7b-hf',
31 | },
32 |
33 | 'llama-13b-hf': {
34 | 'model_type': 'llama',
35 | 'model_name_or_path': '/cjl/pretrained_models/llama-13b-hf',
36 | 'config_name': '/cjl/pretrained_models/llama-13b-hf/config.json',
37 | 'tokenizer_name': '/cjl/pretrained_models/llama-13b-hf',
38 | },
39 |
40 | # TODO change model path
41 | 'Llama-2-7b-chat-hf':{
42 | 'model_type': 'llama',
43 | 'model_name_or_path': '/cjl/pretrained_models/Llama-2-7b-chat-hf',
44 | 'config_name': '/cjl/pretrained_models/Llama-2-7b-chat-hf/config.json',
45 | 'tokenizer_name': '/cjl/pretrained_models/Llama-2-7b-chat-hf',
46 | },
47 |
48 | 'Llama2-Chinese-7b-Chat':{
49 | 'model_type': 'llama',
50 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-7b-Chat',
51 | 'config_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-7b-Chat/config.json',
52 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-7b-Chat',
53 | },
54 |
55 | 'Llama2-Chinese-13b-Chat':{
56 | 'model_type': 'llama',
57 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-13b-Chat',
58 | 'config_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-13b-Chat/config.json',
59 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/Llama2-Chinese-13b-Chat',
60 | },
61 |
62 | 'chatyuan-7b': {
63 | 'model_type': 'llama',
64 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/ChatYuan-7B',
65 | 'config_name': '/data/nlp/pre_models/torch/llama/ChatYuan-7B/config.json',
66 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/ChatYuan-7B',
67 | },
68 | 'tigerbot-13b-chat': {
69 | 'model_type': 'llama',
70 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat',
71 | 'config_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat/config.json',
72 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat',
73 | },
74 | 'tigerbot-13b-chat-int4': {
75 | 'model_type': 'llama',
76 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat-int4',
77 | 'config_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat-int4/config.json',
78 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/tigerbot-13b-chat-int4',
79 | },
80 |
81 | 'openbuddy-llama2-70b-v10.1': {
82 | 'model_type': 'llama',
83 | 'model_name_or_path': '/data/nlp/pre_models/torch/llama/openbuddy-llama2-70b-v10.1-bf16',
84 | 'config_name': '/data/nlp/pre_models/torch/llama/openbuddy-llama2-70b-v10.1-bf16/config.json',
85 | 'tokenizer_name': '/data/nlp/pre_models/torch/llama/openbuddy-llama2-70b-v10.1-bf16',
86 | },
87 |
88 |
89 |
90 | 'rwkv-4-430m-pile': {
91 | 'model_type': 'rwkv',
92 | 'model_name_or_path': '/data/nlp/pre_models/torch/rwkv/rwkv-4-430m-pile',
93 | 'config_name': '/data/nlp/pre_models/torch/rwkv/rwkv-4-430m-pile/config.json',
94 | 'tokenizer_name': '/data/nlp/pre_models/torch/rwkv/rwkv-4-430m-pile',
95 | },
96 |
97 | }
98 |
99 |
100 | # 'target_modules': ['query_key_value'], # bloom,gpt_neox
101 | # 'target_modules': ["q_proj", "v_proj"], #llama,opt,gptj,gpt_neo
102 | # 'target_modules': ['c_attn'], #gpt2
103 | # 'target_modules': ['project_q','project_v'] # cpmant
104 |
105 | train_target_modules_maps = {
106 | 't5': ['qkv_proj'],
107 | 'moss': ['qkv_proj'],
108 | 'chatglm': ['query_key_value'],
109 | 'bloom' : ['query_key_value'],
110 | 'gpt_neox' : ['query_key_value'],
111 | 'llama' : ["q_proj", "v_proj"],
112 | 'opt' : ["q_proj", "v_proj"],
113 | 'gptj' : ["q_proj", "v_proj"],
114 | 'gpt_neo' : ["q_proj", "v_proj"],
115 | 'gpt2' : ['c_attn'],
116 | 'cpmant' : ['project_q','project_v'],
117 | 'rwkv' : ['key','value','receptance'],
118 | }
119 |
120 |
121 | train_model_config = __model_path__['Llama-2-7b-chat-hf']
122 |
123 |
--------------------------------------------------------------------------------
/src/training/config/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_allow_untested_optimizer": true,
3 | "fp16": {
4 | "enabled": true,
5 | "auto_cast": false,
6 | "loss_scale": 0,
7 | "initial_scale_power": 16,
8 | "loss_scale_window": 1000,
9 | "hysteresis": 2,
10 | "min_loss_scale": 1
11 | },
12 | "zero_optimization": {
13 | "stage": 2,
14 | "allgather_partitions": true,
15 | "allgather_bucket_size": 5e8,
16 | "overlap_comm": false,
17 | "reduce_scatter": true,
18 | "reduce_bucket_size": 5e8,
19 | "contiguous_gradients" : true,
20 |
21 | "stage3_max_live_parameters" : 1e9,
22 | "stage3_max_reuse_distance" : 1e9,
23 | "stage3_prefetch_bucket_size" : 5e8,
24 | "stage3_param_persistence_threshold" : 1e6,
25 | "sub_group_size" : 1e12,
26 | "elastic_checkpoint" : true,
27 | "stage3_gather_16bit_weights_on_model_save": true,
28 | "ignore_unused_parameters": true,
29 | "round_robin_gradients": true
30 | }
31 | }
--------------------------------------------------------------------------------
/src/training/config/deepspeed_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "steps_per_print": 1,
3 | "gradient_clipping": 1.0,
4 | "optimizer": {
5 | "type": "AdamW",
6 | "params": {
7 | "lr": 0,
8 | "betas": [0.9, 0.999],
9 | "eps": 1e-8,
10 | "weight_decay": 1e-2
11 | }
12 | },
13 | "scheduler": {
14 | "type": "WarmupDecayLR",
15 | "params": {
16 | "warmup_min_lr": 0,
17 | "warmup_max_lr": 2e-5,
18 | "warmup_num_steps": "auto",
19 | "warmup_type": "linear",
20 | "total_num_steps": "auto"
21 | }
22 | },
23 | "zero_allow_untested_optimizer": true,
24 | "fp16": {
25 | "enabled": false
26 | },
27 | "zero_optimization": {
28 | "stage": 2,
29 | "allgather_partitions": true,
30 | "allgather_bucket_size": 5e8,
31 | "overlap_comm": false,
32 | "reduce_scatter": true,
33 | "reduce_bucket_size": 5e8,
34 | "contiguous_gradients": true,
35 | "stage3_max_live_parameters": 1e9,
36 | "stage3_max_reuse_distance": 1e9,
37 | "stage3_prefetch_bucket_size": 5e8,
38 | "stage3_param_persistence_threshold": 1e6,
39 | "sub_group_size": 1e12,
40 | "elastic_checkpoint": true,
41 | "stage3_gather_16bit_weights_on_model_save": true,
42 | "ignore_unused_parameters": true,
43 | "round_robin_gradients": true,
44 | "offload_optimizer": {
45 | "device": "cpu",
46 | "pin_memory": true
47 | }
48 | }
49 | }
--------------------------------------------------------------------------------
/src/training/config/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : ssbuild
3 | # @Time : 2023/5/31 14:43
4 | import json
5 | import os
6 | import torch
7 | from transformers import BitsAndBytesConfig
8 |
9 | # global configuration
10 | global_args = {
11 | # training configuration
12 | **dict(
13 | trainer_backend ='pl', # one of pl , hf
14 | enable_deepspeed = True,
15 | enable_ptv2 = False,
16 | enable_lora = False,
17 | load_in_bit = 0, # 4 load_in_4bit, 8 load_in_8bit other 0
18 | ),
19 | # merged into the transformers config
20 | "config_merge": {
21 | },
22 | # qlora
23 | "quantization_config": BitsAndBytesConfig(
24 | load_in_8bit =False,
25 | load_in_4bit = False,
26 | llm_int8_threshold=6.0,
27 | llm_int8_has_fp16_weight=False,
28 | bnb_4bit_compute_dtype=torch.float16 if not torch.cuda.is_bf16_supported() else torch.bfloat16,
29 | bnb_4bit_use_double_quant=True,
30 | bnb_4bit_quant_type="nf4",
31 | ),
32 | }
33 |
34 |
35 |
36 |
37 |
38 | if global_args["enable_lora"]:
39 | from config.sft_config_lora import train_info_args,train_info_args_hf,train_model_config
40 | elif global_args["enable_ptv2"]:
41 | from config.sft_config_ptv2 import train_info_args,train_info_args_hf,train_model_config
42 | else:
43 | from config.sft_config import train_info_args,train_info_args_hf,train_model_config
44 |
45 |
46 | if global_args["trainer_backend"] == "hf":
47 | train_info_args = train_info_args_hf
48 |
49 |
50 |
51 |
52 |
53 | def patch_args(train_info_args):
54 | assert global_args["enable_lora"] + global_args["enable_ptv2"] <= 1 , ValueError("lora and ptv2 cannot be enabled at the same time")
55 |
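   | # map load_in_bit onto the qlora quantization flags; 0 disables quantization entirely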
56 | if global_args['quantization_config'] is not None:
57 | global_args['quantization_config'].load_in_4bit = global_args["load_in_bit"] == 4
58 | global_args['quantization_config'].load_in_8bit = global_args["load_in_bit"] == 8
59 | if global_args["load_in_bit"] == 0:
60 | global_args["quantization_config"] = None
61 |
62 | if global_args["enable_lora"]:
63 | # check whether lora or adalora is configured
64 | if 'lora' not in train_info_args and 'adalora' not in train_info_args:
65 | raise ValueError('please configure lora or adalora')
66 | if train_info_args.get('lora',{}).get('with_lora',False) and train_info_args.get('adalora',{}).get('with_lora',False):
67 | raise Exception('only one of lora and adalora can be enabled at a time!')
68 |
69 | train_info_args.pop('prompt', None)
70 | elif global_args["enable_ptv2"]:
71 | train_info_args.pop('lora', None)
72 | train_info_args.pop('adalora', None)
73 | if "gradient_checkpointing" in train_info_args:
74 | train_info_args["gradient_checkpointing"] = False
75 | else:
76 | train_info_args.pop('lora',None)
77 | train_info_args.pop('adalora', None)
78 | train_info_args.pop('prompt', None)
79 |
80 | # preprocessing
81 | if 'rwkv' in train_info_args[ 'tokenizer_name' ].lower():
82 | train_info_args[ 'use_fast_tokenizer' ] = True
83 |
84 |
85 |
86 | patch_args(train_info_args)
87 |
88 |
89 | def get_deepspeed_config(precision='fp16'):
90 | '''
91 | lora / prompt (ptv2) finetuning uses deepspeed_offload.json
92 | full-parameter finetuning uses deepspeed.json
93 | '''
94 | # whether deepspeed is enabled
95 | if not global_args["enable_deepspeed"]:
96 | return None
97 | precision = str(precision).lower()
98 | # choose the deepspeed config file
99 | is_need_update_config = False
100 | if global_args["enable_lora"]:
101 | is_need_update_config = True
102 | filename = os.path.join(os.path.dirname(__file__), 'deepspeed_offload.json')
103 | else:
104 | # filename = os.path.join(os.path.dirname(__file__), 'deepspeed.json')
105 | is_need_update_config = True
106 | filename = os.path.join(os.path.dirname(__file__), 'deepspeed_offload.json')
107 |
108 |
109 | with open(filename, mode='r', encoding='utf-8') as f:
110 | deepspeed_config = json.loads(f.read())
111 |
112 | # lora offload: sync the optimizer configuration
113 | if is_need_update_config:
114 | optimizer = deepspeed_config.get('optimizer',None)
115 | if optimizer:
116 | if global_args["trainer_backend"] == 'hf':
117 | optimizer[ 'params' ][ 'betas' ] = (train_info_args.get('adam_beta1', 0.9),train_info_args.get('adam_beta2', 0.999),)
118 | optimizer[ 'params' ][ 'lr' ] = train_info_args.get('learning_rate', 2e-5)
119 | optimizer[ 'params' ][ 'eps' ] = train_info_args.get('adam_epsilon', 1e-8)
120 | # only takes effect with the deepspeed_offload optimizer
121 | train_info_args[ 'optim' ] = optimizer[ 'type' ]
122 | else:
123 | optimizer['params']['betas'] = train_info_args.get('optimizer_betas', (0.9, 0.999))
124 | optimizer['params']['lr'] = train_info_args.get('learning_rate', 2e-5)
125 | optimizer['params']['eps'] = train_info_args.get('adam_epsilon', 1e-8)
126 | # only takes effect with the deepspeed_offload optimizer
127 | train_info_args['optimizer'] = optimizer['type']
128 |
129 | if precision == 'bf16':
130 | if 'fp16' in deepspeed_config:
131 | deepspeed_config["fp16"]["enbale"] = False
132 | if 'bf16' in deepspeed_config:
133 | deepspeed_config["bf16"]["enbale"] = True
134 | else:
135 | deepspeed_config['bf16'] = {"enbale": True}
136 | elif precision == 'fp16':
137 | if 'bf16' in deepspeed_config:
138 | deepspeed_config["bf16"]["enbale"] = False
139 |
140 | return deepspeed_config
141 |
142 |
--------------------------------------------------------------------------------
/src/training/config/sft_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/5/16 10:13
3 |
4 | import json
5 | import os
6 | import torch
7 | from config.constant_map import train_model_config
8 |
9 |
10 | train_info_args = {
11 | 'devices': [0, 1, 2, 3, 4, 5, 6, 7],
12 | # 'devices': [0],
13 | 'data_backend': 'parquet', # one of record, lmdb, arrow_stream, arrow_file, parquet; lmdb suits very large datasets but uses more disk space than record
14 |
15 | # pretrained model configuration
16 | **train_model_config,
17 |
18 |
19 | 'convert_onnx': False, # convert the model to onnx
20 | 'do_train': True,
21 | # TODO change training file path
22 | 'train_file': [ './data/train.jsonl'],
23 | 'max_epochs': 5,
24 | 'max_steps': -1,
25 |
26 | # *** optimizer
27 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused,
28 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit,
29 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit,
30 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp
31 |
32 | # *** scheduler
33 | # linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial,
34 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau
35 |
36 | # 'optimizer': 'lion',
37 | # 'scheduler_type': 'CAWR',
38 | # 'scheduler':{'T_mult': 1,
39 | # 'rewarm_epoch_num': 0.5, # when max_epochs is not None !
40 | # # 'T_0': 50000, # when max_epochs is None, set the number of steps
41 | # 'verbose': False},
42 |
43 | # 'scheduler_type': 'linear',# one of [linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau
44 | # 'scheduler': None,
45 |
46 | # alternative scheduler types
47 | # 'scheduler_type': 'WarmupCosine',
48 | # 'scheduler': None,
49 |
50 | # 'scheduler_type': 'ReduceLROnPlateau',
51 | # 'scheduler': None,
52 |
53 | # 'scheduler_type': 'Step',
54 | # 'scheduler':{ 'decay_rate': 0.999,'decay_steps': 100,'verbose': True},
55 |
56 | # 'scheduler_type': 'CAWR',
57 | # 'scheduler':{'T_mult': 1, 'rewarm_epoch_num': 2, 'verbose': True},
58 |
59 | # 'scheduler_type': 'CAL',
60 | # 'scheduler': {'rewarm_epoch_num': 2,'verbose': True},
61 |
62 |
63 | 'optimizer_betas': (0.9, 0.999),
64 | 'train_batch_size': 4,
65 | 'eval_batch_size': 2,
66 | 'test_batch_size': 2,
67 | 'learning_rate': 0, #
68 | 'adam_epsilon': 1e-8,
69 | 'gradient_accumulation_steps': 1,
70 | 'max_grad_norm': 1.0,
71 | 'weight_decay': 0,
72 | 'warmup_steps': 0,
73 | 'output_dir': './output',
74 | 'max_seq_length': 512, #
75 | 'max_target_length': 100, # maximum generation length, reserved field
76 | 'use_fast_tokenizer': False,
77 | #'do_lower_case': False,
78 | "dataloader_drop_last": True,
79 | "dataloader_pin_memory":True,
80 | "dataloader_num_workers": 0,
81 | }
82 |
83 |
84 |
85 |
86 |
87 |
88 | train_info_args_hf = {
89 | 'data_backend': 'parquet', # one of record, lmdb, arrow_stream, arrow_file, parquet; lmdb suits very large datasets but uses more disk space than record
90 | # pretrained model configuration
91 | **train_model_config,
92 |
93 | "output_dir": "./outputs_hf",
94 | "overwrite_output_dir": True,
95 | "num_train_epochs": 20,
96 | "max_steps": -1,
97 | "save_safetensors": False,
98 | "save_strategy": "steps",
99 | "save_steps": 1000,
100 | "save_total_limit": 10,
101 | "seed": 66,
102 | "fp16": True,
103 | 'do_train': True,
104 | 'train_file': [ '/cjl/llm_finetuning/data/prompt_engineer/en/train.jsonl' ],
105 | 'do_eval': False,
106 | 'do_predict': False,
107 | "per_device_train_batch_size": 2,
108 | "per_device_eval_batch_size": 2,
109 | "gradient_accumulation_steps": 1,
110 | "evaluation_strategy": "no",
111 | "eval_steps": 100,
112 | "optim": "adamw_torch",
113 | "lr_scheduler_type": "cosine", # one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau
114 | "torch_compile": False,
115 | "learning_rate": 2e-5,
116 | "adam_beta1": 0.9,
117 | "adam_beta2": 0.999,
118 | "adam_epsilon": 1e-8,
119 | "max_grad_norm": 1.0,
120 | "weight_decay": 0.,
121 | "warmup_ratio": 0.03,
122 | "logging_strategy": "steps",
123 | "logging_steps": 10,
124 | "tf32": False,
125 | "gradient_checkpointing": True,
126 | 'max_seq_length': 512, #
127 | 'max_target_length': 100, # maximum generation length, reserved field
128 | 'use_fast_tokenizer': False,
129 | # 'do_lower_case': False,
130 | "dataloader_drop_last": True,
131 | "dataloader_pin_memory": True,
132 | "dataloader_num_workers": 0,
133 |
134 | "log_level": "info", # 'info', 'warning', 'error' and 'critical , passive',
135 |
136 |
137 | }
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------
/src/training/config/sft_config_lora.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/5/16 10:13
3 |
4 | import json
5 | import os
6 | import torch
7 | from config.constant_map import train_model_config,train_target_modules_maps
8 |
9 |
10 | # lora-related settings are disabled by default; only one of lora and adalora can be enabled at a time
11 | lora_info_args = {
12 | 'with_lora': True, # whether to enable the lora module
13 | 'lora_type': 'lora',
14 | 'r': 8,
15 | 'target_modules': train_target_modules_maps[train_model_config['model_type']],
16 | 'lora_alpha': 32,
17 | 'lora_dropout': 0.1,
18 | 'fan_in_fan_out': False,
19 | 'bias': 'none', # Bias type for Lora. Can be 'none', 'all' or 'lora_only'"
20 | 'modules_to_save' : None, # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
21 | 'layers_to_transform': None,
22 | 'layers_pattern': None,
23 | }
24 |
25 | adalora_info_args = {
26 | 'with_lora': False, # whether to enable the adalora module
27 | 'lora_type': 'adalora',
28 | 'r': 8,
29 | 'target_modules': train_target_modules_maps[train_model_config['model_type']],
30 | 'lora_alpha': 32,
31 | 'lora_dropout': 0.1,
32 | 'fan_in_fan_out': False,
33 | 'bias': 'none', # Bias type for Lora. Can be 'none', 'all' or 'lora_only'"
34 | 'modules_to_save' : None, # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
35 | 'layers_to_transform': None,
36 | 'layers_pattern': None,
37 |
38 | 'target_r':8, # Target Lora matrix dimension.
39 | 'init_r': 12, # Initial Lora matrix dimension.
40 | 'tinit': 0, #The steps of initial warmup.
41 | 'tfinal': 0, #The steps of final warmup.
42 | 'deltaT': 1, #Step interval of rank allocation.
43 | 'beta1': 0.85, #Hyperparameter of EMA.
44 | 'beta2': 0.85, #Hyperparameter of EMA.
45 | 'orth_reg_weight': 0.5, #The orthogonal regularization coefficient.
46 | 'total_step': None, #The total training steps.
47 | 'rank_pattern': None, #The saved rank pattern.
48 | }
49 |
50 |
51 | train_info_args = {
52 | 'devices': 1,
53 | 'data_backend': 'parquet', # one of record, lmdb, arrow_stream, arrow_file, parquet; lmdb suits very large datasets but uses more disk space than record
54 | # pretrained model configuration
55 | **train_model_config,
56 | 'convert_onnx': False, # convert the model to onnx
57 | 'do_train': True,
58 | 'train_file': [ './data/finetune_train_examples.json'],
59 | 'max_epochs': 20,
60 | 'max_steps': -1,
61 |
62 | # *** optimizer
63 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused,
64 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit,
65 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit,
66 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp
67 |
68 | # *** scheduler
69 | # linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial,
70 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau
71 | 'optimizer': 'lion',
72 | 'scheduler_type': 'CAWR',
73 | 'scheduler':{'T_mult': 1,
74 | 'rewarm_epoch_num': 0.5, # when max_epochs is not None !
75 | # 'T_0': 50000, # when max_epochs is None, set the number of steps
76 | 'verbose': False},
77 |
78 | # 'scheduler_type': 'linear',# one of [linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau
79 | # 'scheduler': None,
80 |
81 | # alternative scheduler types
82 | # 'scheduler_type': 'WarmupCosine',
83 | # 'scheduler': None,
84 |
85 | # 'scheduler_type': 'ReduceLROnPlateau',
86 | # 'scheduler': None,
87 |
88 | # 'scheduler_type': 'Step',
89 | # 'scheduler':{ 'decay_rate': 0.999,'decay_steps': 100,'verbose': True},
90 |
91 | # 'scheduler_type': 'CAWR',
92 | # 'scheduler':{'T_mult': 1, 'rewarm_epoch_num': 2, 'verbose': True},
93 |
94 | # 'scheduler_type': 'CAL',
95 | # 'scheduler': {'rewarm_epoch_num': 2,'verbose': True},
96 |
97 |
98 | 'optimizer_betas': (0.9, 0.999),
99 | 'train_batch_size': 2,
100 | 'eval_batch_size': 2,
101 | 'test_batch_size': 2,
102 | 'learning_rate': 2e-4, #
103 | 'adam_epsilon': 1e-8,
104 | 'gradient_accumulation_steps': 1,
105 | 'max_grad_norm': 1.0,
106 | 'weight_decay': 0,
107 | 'warmup_steps': 0,
108 | 'output_dir': './output',
109 | 'max_seq_length': 512, #
110 | 'max_target_length': 100, # maximum generation length, reserved field
111 | 'use_fast_tokenizer': False,
112 | #'do_lower_case': False,
113 |
114 | ############## lora settings
115 | 'lora': lora_info_args,
116 | 'adalora': adalora_info_args,
117 | "dataloader_drop_last": True,
118 | "dataloader_pin_memory": True,
119 | "dataloader_num_workers": 0,
120 |
121 | }
122 |
123 |
124 |
125 |
126 | train_info_args_hf = {
127 | 'data_backend': 'parquet',
128 | # one of record, lmdb, arrow_stream, arrow_file, parquet; lmdb suits very large datasets but uses more disk space than record
129 | # pretrained model configuration
130 | **train_model_config,
131 |
132 | "output_dir": "./outputs_hf",
133 | "overwrite_output_dir": True,
134 | "num_train_epochs": 20,
135 | "max_steps": -1,
136 | "save_safetensors": False,
137 | "save_strategy": "steps",
138 | "save_steps": 1000,
139 | "save_total_limit": 10,
140 | "seed": 42,
141 | "fp16": True,
142 | 'do_train': True,
143 | 'train_file': [ './data/finetune_train_examples.json'],
144 | 'do_eval': False,
145 | 'do_predict': False,
146 | "per_device_train_batch_size": 2,
147 | "per_device_eval_batch_size": 2,
148 | "gradient_accumulation_steps": 1,
149 | "evaluation_strategy": "no",
150 | "eval_steps": 100,
151 | "optim": "adamw_torch",
152 | "lr_scheduler_type": "cosine", # one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau
153 | "torch_compile": False,
154 | "learning_rate": 2e-5,
155 | "adam_beta1": 0.9,
156 | "adam_beta2": 0.999,
157 | "adam_epsilon": 1e-8,
158 | "max_grad_norm": 1.0,
159 | "weight_decay": 0.,
160 | "warmup_ratio": 0.03,
161 | "logging_strategy": "steps",
162 | "logging_steps": 10,
163 | "tf32": False,
164 | "gradient_checkpointing": True,
165 | 'max_seq_length': 512, #
166 | 'max_target_length': 100, # maximum generation length, reserved field
167 | 'use_fast_tokenizer': False,
168 | # 'do_lower_case': False,
169 | "dataloader_drop_last": True,
170 | "dataloader_pin_memory": True,
171 | "dataloader_num_workers": 0,
172 |
173 | "log_level": "info", # 'info', 'warning', 'error' and 'critical , passive',
174 | ############## lora settings
175 | 'lora': lora_info_args,
176 | 'adalora': adalora_info_args,
177 |
178 | }
179 |
180 |
181 |
--------------------------------------------------------------------------------
/src/training/config/sft_config_ptv2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/5/16 10:13
3 |
4 | import json
5 | import os
6 |
7 | from config.constant_map import train_model_config
8 |
9 |
10 |
11 | prompt_info_args = {
12 | "with_prompt": True,
13 | "prompt_type": "prefix_tuning", # one of prompt_tuning,p_tuning,prefix_tuning,adaption_prompt
14 | "task_type": "causal_lm", # one of seq_cls,seq_2_seq_lm,causal_lm,token_cls
15 | "prefix_projection": False, # Whether to project the prefix tokens"
16 | "num_virtual_tokens": 32, # Number of virtual tokens
17 | # "token_dim": 2048, # The hidden embedding dimension of the base transformer model.
18 | # "num_transformer_submodules": 1, # The number of transformer submodules in the base transformer model.
19 | # "num_attention_heads" : 24, # The number of attention heads in the base transformer model.
20 | # "num_layers": 1, # The number of layers in the base transformer model.
21 | # "encoder_hidden_size": 2048, # The hidden size of the encoder
22 | }
23 |
24 | train_info_args = {
25 | 'devices': 1,
26 | 'data_backend': 'parquet', # one of record, lmdb, arrow_stream, arrow_file, parquet; lmdb suits very large datasets but uses more disk space than record
27 | # pretrained model configuration
28 | **train_model_config,
29 |
30 | 'convert_onnx': False, # convert the model to onnx
31 | 'do_train': True,
32 | 'train_file': [ './data/finetune_train_examples.json'],
33 | 'max_epochs': 20,
34 | 'max_steps': -1,
35 |
36 | # *** optimizer
37 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused,
38 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit,
39 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit,
40 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp
41 |
42 | # *** scheduler
43 | # linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial,
44 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau
45 |
46 | 'optimizer': 'lion',
47 | 'scheduler_type': 'CAWR',
48 | 'scheduler':{'T_mult': 1,'rewarm_epoch_num': 0.5,
49 | # when max_epochs is not None !
50 | # 'T_0': 50000, # when max_epochs is None, set the number of steps
51 | 'verbose': False},
52 |
53 | # 'scheduler_type': 'linear',# one of [linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau
54 | # 'scheduler': None,
55 |
56 | # alternative scheduler types
57 | # 'scheduler_type': 'WarmupCosine',
58 | # 'scheduler': None,
59 |
60 | # 'scheduler_type': 'ReduceLROnPlateau',
61 | # 'scheduler': None,
62 |
63 | # 'scheduler_type': 'Step',
64 | # 'scheduler':{ 'decay_rate': 0.999,'decay_steps': 100,'verbose': True},
65 |
66 | # 'scheduler_type': 'CAWR',
67 | # 'scheduler':{'T_mult': 1, 'rewarm_epoch_num': 2, 'verbose': True},
68 |
69 | # 'scheduler_type': 'CAL',
70 | # 'scheduler': {'rewarm_epoch_num': 2,'verbose': True},
71 |
72 |
73 | 'optimizer_betas': (0.9, 0.999),
74 | 'train_batch_size': 2,
75 | 'eval_batch_size': 2,
76 | 'test_batch_size': 2,
77 | 'learning_rate': 5e-4, #
78 | 'adam_epsilon': 1e-8,
79 | 'gradient_accumulation_steps': 1,
80 | 'max_grad_norm': 1.0,
81 | 'weight_decay': 0,
82 | 'warmup_steps': 0,
83 | 'output_dir': './output',
84 | 'max_seq_length': 512, #
85 | 'max_target_length': 100, # maximum generation length, reserved field
86 | 'use_fast_tokenizer': False,
87 | #'do_lower_case': False,
88 | "dataloader_drop_last": True,
89 | "dataloader_pin_memory": True,
90 | "dataloader_num_workers": 0,
91 |
92 | ############## ptv2 prompt settings
93 | 'prompt': prompt_info_args,
94 |
95 | }
96 |
97 |
98 |
99 |
100 |
101 |
102 | train_info_args_hf = {
103 | 'data_backend': 'parquet',
104 | # one of record, lmdb, arrow_stream, arrow_file, parquet; lmdb suits very large datasets but uses more disk space than record
105 | # pretrained model configuration
106 | **train_model_config,
107 |
108 | "output_dir": "./outputs_hf",
109 | "overwrite_output_dir": True,
110 | "num_train_epochs": 20,
111 | "max_steps": -1,
112 | "save_safetensors": False,
113 | "save_strategy": "steps",
114 | "save_steps": 1000,
115 | "save_total_limit": 10,
116 | "seed": 42,
117 | "fp16": True,
118 | 'do_train': True,
119 | 'train_file': [ './data/finetune_train_examples.json'],
120 | 'do_eval': False,
121 | 'do_predict': False,
122 | "per_device_train_batch_size": 2,
123 | "per_device_eval_batch_size": 2,
124 | "gradient_accumulation_steps": 1,
125 | "evaluation_strategy": "no",
126 | "eval_steps": 100,
127 | "optim": "adamw_torch",
128 | "lr_scheduler_type": "cosine",# one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau
129 | "torch_compile": False,
130 | "learning_rate": 2e-5,
131 | "adam_beta1": 0.9,
132 | "adam_beta2": 0.999,
133 | "adam_epsilon": 1e-8,
134 | "max_grad_norm": 1.0,
135 | "weight_decay": 0.,
136 | "warmup_ratio": 0.03,
137 | "logging_strategy": "steps",
138 | "logging_steps": 10,
139 | "tf32": False,
140 | "gradient_checkpointing": True,
141 | 'max_seq_length': 512, #
142 | 'max_target_length': 100, # maximum generation length, reserved field
143 | 'use_fast_tokenizer': False,
144 | # 'do_lower_case': False,
145 | "dataloader_drop_last": True,
146 | "dataloader_pin_memory": True,
147 | "dataloader_num_workers": 0,
148 |
149 | "log_level": "info", # 'info', 'warning', 'error' and 'critical , passive',
150 | ############## ptv2 prompt settings
151 | 'prompt': prompt_info_args,
152 | }
153 |
154 |
--------------------------------------------------------------------------------
/src/training/data/finetune_train_examples.jsonl:
--------------------------------------------------------------------------------
1 | {"id": 0, "paragraph": [{"q": "Make me a cup of tea.", "a": "Please provide me with instructions on how to make a cup of tea."}]}
2 | {"id": 1, "paragraph": [{"q": "Give me 5 first date ideas", "a": "Provide 5 first date ideas with reasons for each suggestion."}]}
3 |
--------------------------------------------------------------------------------
/src/training/data_processer.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/3/25 18:36
2 | # @Author : tk
3 | import copy
4 | from enum import Enum
5 | import numpy as np
6 | from transformers import PreTrainedTokenizer
7 |
8 | DEFAULT_PAD_TOKEN = "[PAD]"
9 | DEFAULT_EOS_TOKEN = ""
10 | DEFAULT_BOS_TOKEN = ""
11 | DEFAULT_UNK_TOKEN = ""
12 |
13 | class DataStrategy(Enum):
14 | tunction = 1
15 | slidding = 2
16 |
17 |
18 |
19 | def build_template_llama(query, answer = None,prefix=None, history=None):
20 | return query
21 |
22 |
23 | def build_template_default(query, answer = None,prefix=None, history=None):
24 | prompt = "[INST] You are an expert prompt engineer. Please help me improve this prompt to get a more helpful and harmless response:\n{} [/INST]".format(query)
25 | return prompt
26 |
27 | def build_template_tiger(query,answer = None,prefix=None, history=None):
28 | prompt = prefix or ''
29 | tok_ins = "\n\n### Instruction:\n"
30 | tok_res = "\n\n### Response:\n"
31 | if history is not None:
32 | for q,a in history:
33 | prompt += "{}{}{}{}".format(tok_ins,q,tok_res,a)
34 |
35 | prompt += "{}{}{}".format(tok_ins, query, tok_res)
36 | if answer is not None:
37 | prompt += answer
38 | return prompt
39 |
40 |
41 | # switch the prompt template here
42 | build_template = build_template_default
43 | # build_template = build_template_llama
44 |
45 |
46 | class TokenIdsMaker:
47 | @classmethod
48 | def final(cls, tokenizer, input_ids, labels, max_seq_length):
49 | seqlen = np.asarray(len(input_ids), dtype=np.int32)
50 | pad_len = max_seq_length - seqlen
51 | input_ids = np.asarray(input_ids, dtype=np.int32)
52 | attention_mask = np.asarray([1] * len(input_ids), dtype=np.int32)
53 | labels = np.asarray(labels, dtype=np.int32)
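   | # pad to max_seq_length with the eos token id; padded label positions are set to -100 so they are ignored by the loss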
54 | if pad_len:
55 | pad_val = tokenizer.eos_token_id
56 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val))
57 | attention_mask = np.pad(attention_mask, (0, pad_len), 'constant', constant_values=(0, 0))
58 | labels = np.pad(labels, (0, pad_len), 'constant', constant_values=(-100, -100))
59 | d = {
60 | 'input_ids': input_ids,
61 | 'attention_mask': attention_mask,
62 | 'labels': labels,
63 | 'seqlen': seqlen
64 | }
65 | return d
66 | @classmethod
67 | def tunction(cls, tokenizer: PreTrainedTokenizer, config, sup, max_seq_length, examples):
68 | sptoken = [config.bos_token_id]
69 | ds = []
70 | prefix, examples = examples
71 | max_a_b_len = 0
72 | for sid, (q, a) in enumerate(examples):
73 | a_ids = tokenizer.encode(text=build_template(q,prefix=prefix,history=examples[:sid]), add_special_tokens=False)
74 | # from IPython import embed
75 | # embed()
76 | b_ids = tokenizer.encode(text=a, add_special_tokens=False)
77 | max_a_b_len = max(max_a_b_len, len(a_ids) + len(b_ids) + len(sptoken) + 1)
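   | # trim until prompt + answer fit in max_seq_length: drop answer tokens from the end if the answer is longer, otherwise drop prompt tokens from the front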
78 | while len(a_ids) + len(b_ids) > max_seq_length - len(sptoken) - 1:
79 | if len(b_ids) > len(a_ids):
80 | b_ids.pop(-1)
81 | else:
82 | a_ids.pop(0)
83 | b_ids += [config.eos_token_id]
84 | input_ids = a_ids + b_ids
85 | labels = copy.deepcopy(input_ids) if not sup else [-100] * len(a_ids) + copy.deepcopy(b_ids)
86 | input_ids = sptoken + input_ids
87 | labels = sptoken + labels if not sup else [-100] * len(sptoken) + labels
88 | assert len(input_ids) <= max_seq_length
89 | ds.append(cls.final(tokenizer, input_ids, labels, max_seq_length))
90 | return ds, max_a_b_len
91 |
92 |
93 | @classmethod
94 | def slidding(cls, tokenizer: PreTrainedTokenizer,config,stride,max_seq_length, examples,
95 | sliding_size=None,
96 | src_max_length=-1,
97 | dst_max_length=-1,
98 | sup=True):
99 | sptoken = [config.bos_token_id]
100 | if sliding_size is None or sliding_size < 0:
101 | sliding_size = max_seq_length - len(sptoken)
102 |
103 | assert sliding_size <= max_seq_length - len(sptoken)
104 |
105 | ds = []
106 | prefix, examples = examples
107 | for sid, (q, a) in enumerate(examples):
108 | a_ids = tokenizer.encode(text=build_template(q, prefix=prefix, history=examples[:sid]),add_special_tokens=False)
109 | b_ids = tokenizer.encode(text=a, add_special_tokens=False)
110 | if src_max_length and src_max_length > 0:
111 | a_ids = a_ids[:src_max_length]
112 | if dst_max_length and dst_max_length > 0:
113 | b_ids = b_ids[:dst_max_length]
114 |
115 | b_ids += [config.eos_token_id]
116 | input_ids_qa = a_ids + b_ids
117 | labels_all = copy.deepcopy(input_ids_qa) if not sup else [-100] * len(a_ids) + b_ids
118 |
119 | pos = 0
120 | while pos < len(input_ids_qa):
121 | input_ids = input_ids_qa[pos:pos + max_seq_length - len(sptoken)]
122 | labels = labels_all[pos:pos + max_seq_length - len(sptoken)]
123 |
124 | pos += sliding_size
125 | if np.all(np.asarray(labels) == -100):
126 | continue
127 |
128 | input_ids = sptoken + input_ids
129 | labels = sptoken + labels if not sup else [-100] * len(sptoken) + labels
130 | ds.append(cls.final(tokenizer, input_ids, labels, max_seq_length))
131 | return ds
132 |
133 |
134 |
--------------------------------------------------------------------------------
/src/training/data_utils.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/1/22 16:22
2 | # @Author : tk
3 | # @FileName: data_utils.py
4 |
5 | import copy
6 | import json
7 | import os
8 | import random
9 | import typing
10 | import numpy as np
11 | import torch
12 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments,TrainingArgumentsHF, DataArguments
13 | from aigc_zoo.model_zoo.llm.llm_model import PetlArguments,LoraConfig,PromptArguments
14 | from fastdatasets.record import load_dataset as Loader, RECORD, WriterObject, gfile
15 | from transformers import PreTrainedTokenizer, HfArgumentParser, PretrainedConfig
16 | from data_processer import DataStrategy, TokenIdsMaker, build_template, DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, \
17 | DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN
18 | from config import *
19 | from module_setup import module_setup
20 |
21 |
22 | module_setup()
23 |
24 | data_conf = {
25 | 'strategy': DataStrategy.tunction, # data strategy option
26 | DataStrategy.tunction: {
27 | 'sup': True, # supervised mode
28 | },
29 |
30 | DataStrategy.slidding: {
31 | 'stride': int(train_info_args['max_seq_length'] / 3 * 2),
32 | 'sup': True, # supervised mode
33 | "src_max_length": train_info_args['max_seq_length'] - 10,
34 | "dst_max_length": None,
35 | }
36 |
37 | }
38 |
39 |
40 |
41 | def preprocess(text):
42 | return text
43 |
44 | def postprocess(text):
45 | return text
46 |
47 |
48 | class NN_DataHelper(DataHelper):
49 | index = 1
50 | data_len = []
51 |
52 | def __init__(self, *args, **kwargs):
53 | super(NN_DataHelper, self).__init__(*args, **kwargs)
54 | assert data_conf[DataStrategy.slidding]['stride'] > 0
55 |
56 | def load_tokenizer_and_config(self, *args, **kwargs):
57 | ret = super().load_tokenizer_and_config(*args, **kwargs)
58 | self._preprocess_tokenizer_config()
59 | return ret
60 |
61 | def _preprocess_tokenizer_config(self):
62 | model_args = self.model_args
63 | tokenizer = self.tokenizer
64 | config = self.config
65 |
66 |
67 |
68 | if "llama" in model_args.model_type.lower():
69 | special_tokens_dict = dict()
70 | # from IPython import embed
71 | # embed()
72 | # if tokenizer.pad_token is None:
73 | # special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
74 | if tokenizer.eos_token is None:
75 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
76 | if tokenizer.bos_token is None:
77 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
78 | if tokenizer.unk_token is None:
79 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
80 |
81 | _ = tokenizer.add_special_tokens(special_tokens_dict)
82 |
83 | # tokenizer.add_special_tokens({
84 | # "eos_token": DEFAULT_EOS_TOKEN,
85 | # "bos_token": DEFAULT_BOS_TOKEN,
86 | # "unk_token": DEFAULT_UNK_TOKEN,
87 | # })
88 | # if tokenizer.pad_token_id is None or tokenizer.pad_token_id == -1:
89 | # tokenizer.pad_token_id = tokenizer.eos_token_id
90 |
91 | if tokenizer.pad_token is None:
92 | tokenizer.add_special_tokens({
93 | "pad_token": tokenizer.eos_token,
94 | })
95 | if config.pad_token_id is None or config.pad_token_id == -1:
96 | config.pad_token_id = tokenizer.eos_token_id
97 |
98 |
99 |
100 | if config.decoder_start_token_id is None:
101 | config.decoder_start_token_id = config.bos_token_id
102 |
103 | if config.decoder_start_token_id != tokenizer.bos_token_id:
104 | print('*' * 30, 'config.decoder_start_token_id != tokenizer.bos_token_id !!!')
105 |
106 | assert config.decoder_start_token_id == config.bos_token_id
107 |
108 | def on_data_ready(self):
109 | self.index = -1
110 |
111 | # tokenization
112 | def on_data_process(self, data: typing.Any, mode: str):
113 | self.index += 1
114 |
115 | tokenizer: PreTrainedTokenizer
116 | config = self.config
117 | max_seq_length = self.max_seq_length_dict[mode]
118 | tokenizer = self.tokenizer
119 |
120 | examples = data
121 | # from IPython import embed
122 | # embed()
123 | # exit()
124 |
125 | strategy = data_conf['strategy']
126 | if strategy == DataStrategy.tunction:
127 | ds,l = TokenIdsMaker.tunction(tokenizer, config=config, max_seq_length=max_seq_length, examples=examples,
128 | **data_conf[strategy])
129 | self.data_len.append(l)
130 | elif strategy == DataStrategy.slidding:
131 | ds = TokenIdsMaker.slidding(tokenizer, config=config, max_seq_length=max_seq_length, examples=examples,
132 | **data_conf[strategy])
133 |
134 | else:
135 | raise ValueError('Invalid strategy', strategy)
136 | if not ds:
137 | return None
138 |
139 | if self.index < 3:
140 | print(ds[0])
141 | # from IPython import embed
142 | # embed()
143 | # exit()
144 | return ds
145 |
146 | def _get_paragraph(self,lines):
147 | D = []
148 | for line_id, line in enumerate(lines):
149 | jd = json.loads(line)
150 | if not jd:
151 | continue
152 | paragraph = jd['paragraph']
153 | if line_id < 10:
154 | print(paragraph)
155 |
156 | prefix = jd.get('p', '')
157 | paragraph = [(preprocess(session['q']),
158 | preprocess('\n'.join(session['a'])) if isinstance(session['a'], list) else preprocess(
159 | session['a']))
160 | for session in paragraph]
161 | sub = []
162 | # build the template manually
163 | # TODO: make a template for llama2
164 | # https://gpus.llm-utils.org/llama-2-prompt-template/
165 | for (q,a) in paragraph:
166 | if not len(a):
167 | continue
168 | assert len(a), ValueError('answer cannot be empty')
169 | sub.append((q, a))
170 | D.append((prefix, copy.deepcopy(sub)))
171 | # from IPython import embed
172 | # embed()
173 | # exit()
174 |
175 | sub.clear()
176 | return D
177 |
178 | def _get_messages(self,lines):
179 | D = []
180 | for line_id, line in enumerate(lines):
181 | jd = json.loads(line)
182 | if not jd:
183 | continue
184 | conversations = jd['conversations']
185 | if line_id < 10:
186 | print(conversations)
187 |
188 | paragraph = []
189 | prefix = ''
190 | pair = [None,None]
191 | for m in conversations:
192 | if m["from"] == 'user':
193 | pair[0] = preprocess(m["value"])
194 | elif m["from"] == 'assistant':
195 | pair[1] = preprocess(m["value"])
196 | elif m["from"] == 'system':
197 | prefix = preprocess(m["value"])
198 | if pair[0] is not None and pair[1] is not None:
199 | paragraph.append(tuple(pair))
200 | pair[0],pair[1] = None,None
201 |
202 | sub = []
203 | # build the template manually
204 | for (q, a) in paragraph:
205 | assert len(a), ValueError('answer cannot be empty')
206 | sub.append((q, a))
207 | D.append((prefix, copy.deepcopy(sub)))
208 | sub.clear()
209 | return D
210 | # read corpus files
211 | def on_get_corpus(self, files: typing.List, mode: str):
212 | D = []
213 | for file in files:
214 | with open(file, mode='r', encoding='utf-8', newline='\n') as f:
215 | lines = f.readlines()
216 | is_new = False
217 | if len(lines) > 0:
218 | is_new = 'conversations' in json.loads(lines[0])
219 | if is_new:
220 | D.extend(self._get_messages(lines))
221 | else:
222 | D.extend(self._get_paragraph(lines))
223 | return D
224 |
225 | def collate_fn(self, batch):
226 | o = {}
227 | for i, b in enumerate(batch):
228 | if i == 0:
229 | for k in b:
230 | o[k] = [torch.tensor(b[k])]
231 | else:
232 | for k in b:
233 | o[k].append(torch.tensor(b[k]))
234 | for k in o:
235 | o[k] = torch.stack(o[k])
236 |
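   | # trim every tensor in the batch to the longest real sequence length to avoid computing over padding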
237 | maxlen = torch.max(o.pop('seqlen'))
238 | o['input_ids'] = o['input_ids'][:, :maxlen]
239 | o['attention_mask'] = o['attention_mask'][:, :maxlen]
240 | o['labels'] = o['labels'][:, :maxlen].long()
241 | return o
242 |
243 | def make_dataset_all(self):
244 | data_args = self.data_args
245 | # schema for arrow parquet
246 | schema = {
247 | "input_ids": "int32_list",
248 | "attention_mask": "int32_list",
249 | "labels": "int32_list",
250 | "seqlen": "int32_list",
251 | }
252 | # cache the dataset
253 | if data_args.do_train:
254 | self.make_dataset_with_args(data_args.train_file, mixed_data=False, shuffle=True, mode='train',
255 | schema=schema)
256 | if data_args.do_eval:
257 | self.make_dataset_with_args(data_args.eval_file, mode='eval', schema=schema)
258 | if data_args.do_test:
259 | self.make_dataset_with_args(data_args.test_file, mode='test', schema=schema)
260 |
261 |
262 |
263 | if __name__ == '__main__':
264 |
265 | if global_args["trainer_backend"] == "hf":
266 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsHF, DataArguments, PetlArguments, PromptArguments),
267 | conflict_handler='resolve')
268 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(train_info_args,
269 | allow_extra_keys=True, )
270 | else:
271 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, PetlArguments, PromptArguments))
272 | model_args, training_args, data_args, _, _ = parser.parse_dict(train_info_args)
273 |
274 | dataHelper = NN_DataHelper(model_args, training_args, data_args)
275 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs={"torch_dtype": torch.float16})
276 |
277 |
278 |
279 |
280 | # cache the dataset
281 | # check whether output/dataset_0-train.record exists; if not, build the dataset
282 | dataHelper.make_dataset_all()
283 | print(np.mean(dataHelper.data_len))
284 | print(np.percentile(dataHelper.data_len, [25, 50, 99]))
285 | print(np.max(dataHelper.data_len))
286 |
287 |
288 | # def shuffle_records(record_filenames, outfile, compression_type='GZIP'):
289 | # print('shuffle_records record...')
290 | # options = RECORD.TFRecordOptions(compression_type=compression_type)
291 | # dataset_reader = Loader.RandomDataset(record_filenames, options=options, with_share_memory=True)
292 | # data_size = len(dataset_reader)
293 | # all_example = []
294 | # for i in tqdm(range(data_size), desc='load records'):
295 | # serialized = dataset_reader[i]
296 | # all_example.append(serialized)
297 | # dataset_reader.close()
298 | #
299 | # shuffle_idx = list(range(data_size))
300 | # random.shuffle(shuffle_idx)
301 | # writer = WriterObject(outfile, options=options)
302 | # for i in tqdm(shuffle_idx, desc='shuffle record'):
303 | # example = all_example[i]
304 | # writer.write(example)
305 | # writer.close()
306 | #
307 | #
308 |     # # shuffle each record file again
309 | # for filename in dataHelper.train_files:
310 | # shuffle_records(filename, filename)
311 |
--------------------------------------------------------------------------------
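A note on the corpus format handled above: on_get_corpus treats a file as the conversation format when the first JSON line contains a `conversations` key, and otherwise falls back to the paragraph format. The sketch below is a minimal, hypothetical record in the conversation format; the field names (`conversations`, `from`, `value`) and the roles (`system`, `user`, `assistant`) come from the parsing loop above, while the sample text is illustrative only.

    import json

    # One JSONL training record in the "conversations" format consumed by _get_messages.
    # 'system' becomes the prefix; each user/assistant pair becomes one (q, a) tuple.
    record = {
        "conversations": [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "user", "value": "How do I get from Nanjing to Shanghai?"},
            {"from": "assistant", "value": "You can take the high-speed rail from Nanjing South ..."},
        ]
    }
    print(json.dumps(record, ensure_ascii=False))  # one line of the training .jsonl file
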
/src/training/infer.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer
4 |
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser
8 | from data_utils import train_info_args, NN_DataHelper, get_deepspeed_config
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer
10 | from aigc_zoo.utils.llm_generate import Generate
11 | from aigc_zoo.model_zoo.llm.llm_model import RotaryNtkScaledArguments,RotaryLinearScaledArguments # aigc-zoo 0.1.20
12 |
13 | deep_config = get_deepspeed_config()
14 |
15 |
16 | def infer_tiger(model,tokenizer,max_input_length=512):
17 | tok_ins = "\n\n### Instruction:\n"
18 | tok_res = "\n\n### Response:\n"
19 | prompt_input = tok_ins + "{instruction}" + tok_res
20 |
21 | generation_config = {
22 | "do_sample": True,
23 | "eos_token_id": 2,
24 | "max_length": max_input_length,
25 | "pad_token_id": 60514,
26 | "repetition_penalty": 1.1,
27 | "temperature": 0.3,
28 | "transformers_version": "4.31.0"
29 | }
30 | text_list = ["写一个诗歌,关于冬天",
31 | "晚上睡不着应该怎么办",
32 | "从南京到上海的路线",
33 | ]
34 |
35 | for input in text_list:
36 | sess_text = ''
37 |
38 | query_text = input.strip()
39 | sess_text += tok_ins + query_text
40 | input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]})
41 | inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
42 | inputs = {k: v.to(model.device) for k, v in inputs.items()}
43 | output = model.generate(**inputs, **generation_config)
44 | output_str = tokenizer.decode(output[0], skip_special_tokens=False, spaces_between_special_tokens=False)
45 | answer = output_str.rsplit(tok_res, 1)[1].strip()
46 | if answer.endswith(tokenizer.eos_token):
47 | answer = answer.rsplit(tokenizer.eos_token, 1)[0].strip()
48 |
49 | print('input', input)
50 | print('output', answer)
51 |
52 | if __name__ == '__main__':
53 |
54 |
55 | parser = HfArgumentParser((ModelArguments,))
56 | (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True)
57 |
58 | dataHelper = NN_DataHelper(model_args)
59 | tokenizer, config, _,_= dataHelper.load_tokenizer_and_config()
60 |
61 | enable_ntk = False
62 | rope_args = None
63 | if enable_ntk and config.model_type == 'llama':
64 |         rope_args = RotaryNtkScaledArguments(name='rotary_emb',max_position_embeddings=2048, alpha=4) # extend context to 8k
65 |         # rope_args = RotaryLinearScaledArguments(name='rotary_emb',max_position_embeddings=2048, scale=4) # extend context to 8k
66 |
67 |
68 | pl_model = MyTransformer(config=config, model_args=model_args,torch_dtype=config.torch_dtype,rope_args=rope_args)
69 | model = pl_model.get_llm_model()
70 | model = model.eval()
71 | if hasattr(model,'quantize'):
72 |         # llama / llama2 quantization is supported
73 |         if not model.quantized:
74 |             # adjust as needed; only 4/8-bit quantization is currently supported, and the quantized model can be saved
75 |             model.half().quantize(4).cuda()
76 |             # save the quantized weights
77 |             # model.save_pretrained('llama2-7b-chat-int4',max_shard_size="2GB")
78 |             # exit(0)
79 |         else:
80 |             # already quantized
81 | model.half().cuda()
82 | else:
83 | model.half().cuda()
84 |
85 | if train_info_args['model_name_or_path'].lower().find('tiger') >=0:
86 | infer_tiger(model,tokenizer)
87 | else:
88 | text_list = ["写一个诗歌,关于冬天",
89 | "晚上睡不着应该怎么办",
90 | "从南京到上海的路线",
91 | ]
92 | for input in text_list:
93 | response = Generate.generate(model, query=input, tokenizer=tokenizer, max_length=512,
94 | eos_token_id=config.eos_token_id,
95 | do_sample=False, top_p=0.7, temperature=0.95, )
96 | print('input', input)
97 | print('output', response)
--------------------------------------------------------------------------------
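A quick sanity check on the "extend context to 8k" comment above: both scaled-RoPE variants stretch the base context window by their factor, so with max_position_embeddings=2048 and alpha (or scale) of 4 the usable context becomes roughly 8192 tokens. The snippet below is illustrative arithmetic only; the actual scaling is implemented inside aigc_zoo.

    # Illustrative arithmetic for the RoPE-scaling comment in infer.py (not library code).
    max_position_embeddings = 2048   # base LLaMA context window
    factor = 4                       # alpha for NTK scaling / scale for linear scaling
    print(max_position_embeddings * factor)  # 8192 -> the "8k" mentioned in the comment
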
/src/training/infer_finetuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer
4 |
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser, AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper, get_deepspeed_config,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer
10 | from aigc_zoo.utils.llm_generate import Generate
11 | import json
12 | from tqdm import tqdm
13 |
14 | deep_config = get_deepspeed_config()
15 |
16 |
17 | if __name__ == '__main__':
18 | # TODO add input file and output file (requires json)
19 | # or you could implement it yourself
20 | input_file = '../../data/data_construction/examples_ctx.json'
21 | output_file = '../../data/data_construction/examples_ctx_optimized_gen.json'
22 |
23 | # optimized on evaluation set
24 | # input_file = '../../data/testset/dolly_eval.json'
25 | # output_file = '../../data/testset/dolly_eval_optimized.json'
26 |
27 | parser = HfArgumentParser((ModelArguments,))
28 | (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True)
29 |
30 | dataHelper = NN_DataHelper(model_args)
31 | tokenizer, _, _,_= dataHelper.load_tokenizer_and_config()
32 |
33 |
34 | config = AutoConfig.from_pretrained('./output/best_ckpt')
35 | pl_model = MyTransformer(config=config, model_args=model_args,torch_dtype=config.torch_dtype,)
36 |
37 |     # commands for converting deepspeed weights with the conversion script
38 |     # usually pick the most recent checkpoint folder (sorted by time)
39 | # cd best_ckpt/last
40 | # python zero_to_fp32.py . ../last.ckpt
41 |
42 | train_weight = './output/best_ckpt'
43 |
44 | pl_model.load_sft_weight(train_weight,strict=True)
45 |
46 |     # save HF weights
47 |     # config.save_pretrained('convert/')
48 |
49 |     # save sft p-tuning-v2 weights
50 |     # pl_model.save_sft_weight('convert/pytorch_model_sft_ptv2.bin')
51 |
52 |     # save sft weights
53 | # pl_model.save_sft_weight('convert/pytorch_model_sft.bin')
54 |
55 | model = pl_model.get_llm_model()
56 |
57 | model.eval().half().cuda()
58 |
59 |
60 | with open(input_file, encoding='utf-8') as f:
61 | text_list = json.load(f)[:]
62 |
63 | gen_res = []
64 |
65 | for input in tqdm(text_list[:]):
66 |
67 | response = Generate.generate(model, query=build_template((input['instruction']+'\n'+input['context']).strip()), tokenizer=tokenizer, max_new_tokens=1024,
68 | eos_token_id=config.eos_token_id,
69 | do_sample=True, top_p=0.9, temperature=0.6, num_beams=1)
70 |
71 | input['gen_res'] = response.strip()
72 | gen_res.append(input)
73 |
74 | with open(output_file, 'w', encoding='utf-8') as f:
75 | json.dump(gen_res, f, indent=4, ensure_ascii=False)
76 |
--------------------------------------------------------------------------------
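For reference, the generation loop above expects input_file to be a JSON list whose items carry `instruction` and `context` fields, and it writes the same items back with an added `gen_res` field. A minimal sketch of that round trip (field names taken from the code, sample content hypothetical):

    import json

    # Shape of one item in input_file, as read by the loop above.
    items = [{"instruction": "Summarize the passage.", "context": "Prompt optimization rewrites user prompts ..."}]

    # The script stores the model output under 'gen_res' and dumps everything to output_file.
    items[0]["gen_res"] = "<generated response>"
    with open("examples_ctx_optimized_gen.json", "w", encoding="utf-8") as f:
        json.dump(items, f, indent=4, ensure_ascii=False)
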
/src/training/infer_lora_finetuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer_lora_finetuning
4 | import os
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser,AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper,global_args,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments,PromptArguments
10 | from aigc_zoo.utils.llm_generate import Generate
11 |
12 |
13 | if __name__ == '__main__':
14 | train_info_args['seed'] = None
15 | parser = HfArgumentParser((ModelArguments,))
16 | (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True)
17 |
18 |
19 | dataHelper = NN_DataHelper(model_args)
20 | tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config()
21 |
22 |
23 |     # usually pick the most recent checkpoint folder (sorted by time)
24 | ckpt_dir = './best_ckpt/last'
25 |
26 | config = AutoConfig.from_pretrained(ckpt_dir)
27 | lora_args = PetlArguments.from_pretrained(ckpt_dir)
28 |
29 | assert lora_args.inference_mode == True
30 |
31 | new_num_tokens = config.vocab_size
32 | if config.task_specific_params is not None and config.task_specific_params.get('vocab_size',None) is not None:
33 | config.vocab_size = config.task_specific_params['vocab_size']
34 |
35 | pl_model = MyTransformer(config=config, model_args=model_args,
36 | lora_args=lora_args,
37 | torch_dtype=config.torch_dtype,
38 | new_num_tokens=new_num_tokens,
39 | # load_in_8bit=global_args["load_in_8bit"],
40 | # # device_map="auto",
41 |                              # device_map = {"":0} # first GPU
42 | )
43 |
44 |     # load lora weights
45 | pl_model.load_sft_weight(ckpt_dir)
46 |
47 | pl_model.eval().half().cuda()
48 |
49 | enable_merge_weight = False
50 |
51 | if enable_merge_weight:
52 |         # merge lora weights and save
53 | pl_model.save_sft_weight(os.path.join(ckpt_dir, 'pytorch_model_merge.bin'), merge_lora_weight=True)
54 | else:
55 | model = pl_model.get_llm_model()
56 |
57 | text_list = ["写一个诗歌,关于冬天",
58 | "晚上睡不着应该怎么办",
59 | "从南京到上海的路线",
60 | ]
61 | for input in text_list:
62 | response = Generate.generate(model, query=build_template(input), tokenizer=tokenizer, max_length=512,
63 | eos_token_id=config.eos_token_id,
64 | do_sample=False, top_p=0.7, temperature=0.95, )
65 | print('input', input)
66 | print('output', response)
--------------------------------------------------------------------------------
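When enable_merge_weight is switched on above, the LoRA deltas are folded into the base weights and written to pytorch_model_merge.bin. A minimal sketch for inspecting that file afterwards, assuming it is a regular PyTorch state dict (an assumption about how save_sft_weight stores it, not documented API):

    import torch

    # Assumption: pytorch_model_merge.bin is a plain torch state dict.
    state_dict = torch.load("./best_ckpt/last/pytorch_model_merge.bin", map_location="cpu")
    print(len(state_dict), "tensors")
    for name, tensor in list(state_dict.items())[:5]:
        print(name, tuple(tensor.shape))
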
/src/training/infer_muti_lora_finetuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer_muti_lora_finetuning
4 | import os
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser,AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper,global_args,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,\
10 | PetlArguments,PromptArguments,PetlModel
11 | from aigc_zoo.utils.llm_generate import Generate
12 |
13 |
14 | if __name__ == '__main__':
15 | train_info_args['seed'] = None
16 | parser = HfArgumentParser((ModelArguments,))
17 | (model_args,) = parser.parse_dict(train_info_args, allow_extra_keys=True)
18 |
19 |
20 | dataHelper = NN_DataHelper(model_args)
21 | tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config()
22 |
23 |
24 |     # usually pick the most recent checkpoint folder (sorted by time)
25 | ckpt_dir = './best_ckpt/last'
26 |
27 | config = AutoConfig.from_pretrained(ckpt_dir)
28 | lora_args = PetlArguments.from_pretrained(ckpt_dir)
29 |
30 | assert lora_args.inference_mode == True
31 |
32 | new_num_tokens = config.vocab_size
33 | if config.task_specific_params is not None and config.task_specific_params.get('vocab_size',None) is not None:
34 | config.vocab_size = config.task_specific_params['vocab_size']
35 |
36 | pl_model = MyTransformer(config=config, model_args=model_args,
37 | lora_args=lora_args,
38 | torch_dtype=config.torch_dtype,
39 | new_num_tokens=new_num_tokens,
40 | # load_in_8bit=global_args["load_in_8bit"],
41 | # # device_map="auto",
42 |                              # device_map = {"":0} # first GPU
43 | )
44 |
45 |     # load the default lora adapter
46 |     pl_model.load_sft_weight(ckpt_dir, adapter_name="default")
47 |
48 |     # load additional lora adapters as needed
49 |     # pl_model.load_sft_weight(ckpt_dir, adapter_name="yourname")
50 |
51 |     # load additional lora adapters as needed
52 |     # pl_model.load_sft_weight(ckpt_dir, adapter_name="yourname")
53 |
54 | pl_model.eval().half().cuda()
55 |
56 |     # the backbone model has been replaced by a PetlModel
57 | lora_model: PetlModel = pl_model.backbone
58 |
59 | text_list = ["写一个诗歌,关于冬天",
60 | "晚上睡不着应该怎么办",
61 | "从南京到上海的路线",
62 | ]
63 |
64 |     # inference with the base model (adapters disabled)
65 | with lora_model.disable_adapter():
66 | for input in text_list:
67 |             # lora_model delegates calls to the underlying model
68 | response = Generate.generate(lora_model, query=build_template(input), tokenizer=tokenizer, max_length=512,
69 | eos_token_id=config.eos_token_id,
70 | do_sample=False, top_p=0.7, temperature=0.95, )
71 | print('input', input)
72 | print('output', response)
73 |
74 | lora_model.set_adapter(adapter_name='default')
75 |
76 | for input in text_list:
77 |         # lora_model delegates calls to the underlying model
78 | response = Generate.generate(lora_model, query=build_template(input), tokenizer=tokenizer, max_length=512,
79 | eos_token_id=config.eos_token_id,
80 | do_sample=False, top_p=0.7, temperature=0.95, )
81 | print('input', input)
82 | print('output', response)
83 |
84 |
85 |
86 |
--------------------------------------------------------------------------------
/src/training/infer_ptuning.py:
--------------------------------------------------------------------------------
1 | # @Time : 2023/4/2 22:49
2 | # @Author : tk
3 | # @FileName: infer_ptuning
4 | import os
5 | import torch
6 | from deep_training.data_helper import ModelArguments
7 | from transformers import HfArgumentParser,AutoConfig
8 | from data_utils import train_info_args, NN_DataHelper,build_template
9 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer,PromptArguments
10 | from aigc_zoo.utils.llm_generate import Generate
11 |
12 | if __name__ == '__main__':
13 | train_info_args['seed'] = None
14 | parser = HfArgumentParser((ModelArguments,))
15 | (model_args,) = parser.parse_dict(train_info_args,allow_extra_keys=True)
16 |
17 | dataHelper = NN_DataHelper(model_args)
18 | tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs={"torch_dtype": torch.float16})
19 |
20 |
21 | train_weight_dir = './best_ckpt/last'
22 | config = AutoConfig.from_pretrained(train_weight_dir)
23 | prompt_args = PromptArguments.from_pretrained(train_weight_dir)
24 |
25 | assert prompt_args.inference_mode == True
26 |
27 | new_num_tokens = config.vocab_size
28 | if config.task_specific_params is not None and config.task_specific_params.get('vocab_size', None) is not None:
29 | config.vocab_size = config.task_specific_params['vocab_size']
30 |
31 | pl_model = MyTransformer(config=config, model_args=model_args,
32 | prompt_args=prompt_args,
33 | new_num_tokens=new_num_tokens,
34 | )
35 |     # load sft weights
36 | pl_model.load_sft_weight(train_weight_dir)
37 |
38 | pl_model.eval().half().cuda()
39 |
40 | model = pl_model.get_llm_model()
41 |
42 |     # base model precision
43 | model.base_model_torch_dtype = torch.half
44 |
45 | text_list = ["写一个诗歌,关于冬天",
46 | "晚上睡不着应该怎么办",
47 | "从南京到上海的路线"]
48 |     for input in text_list:
49 |         response = Generate.generate(model, query=build_template(input), tokenizer=tokenizer, max_length=512,
50 |                                      eos_token_id=config.eos_token_id,
51 |                                      do_sample=False, top_p=0.7, temperature=0.95, )
52 |         print('input', input)
53 |         print('output', response)
--------------------------------------------------------------------------------
/src/training/module_setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : ssbuild
3 | # @Time : 2023/8/16 16:03
4 |
5 | from deep_training.utils.hf import register_transformer_model,register_transformer_config
6 | from transformers import AutoModelForCausalLM
7 | from deep_training.nlp.models.rellama.modeling_llama import LlamaForCausalLM
8 | __all__ = [
9 | "module_setup"
10 | ]
11 |
12 | def module_setup():
13 |     # register the model
14 | #register_transformer_config(XverseConfig)
15 | register_transformer_model(LlamaForCausalLM, AutoModelForCausalLM)
--------------------------------------------------------------------------------
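module_setup above registers the patched LlamaForCausalLM from deep_training with the transformers auto classes. A minimal sketch of the intended effect, assuming registration makes AutoModelForCausalLM resolve llama checkpoints to that class (the checkpoint path is a placeholder):

    from transformers import AutoModelForCausalLM
    from module_setup import module_setup

    module_setup()  # register the deep_training LlamaForCausalLM with the Auto* machinery
    model = AutoModelForCausalLM.from_pretrained("/path/to/llama-checkpoint")  # placeholder path
    print(type(model).__name__)  # expected to be the registered LlamaForCausalLM
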
/src/training/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 | import os.path
4 | import torch
5 | from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments
6 | from deep_training.trainer.pl.modelcheckpoint import ModelCheckpointEx
7 | from lightning import Trainer
8 | from lightning.pytorch.callbacks import LearningRateMonitor
9 | from lightning.pytorch.strategies import DeepSpeedStrategy
10 | from transformers import HfArgumentParser
11 | from data_utils import NN_DataHelper, train_info_args, get_deepspeed_config, global_args
12 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments, LoraConfig, PromptArguments
13 |
14 |
15 | assert global_args["trainer_backend"] == "pl"
16 |
17 | if __name__ == '__main__':
18 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, PetlArguments,PromptArguments))
19 | model_args, training_args, data_args, lora_args,prompt_args = parser.parse_dict(train_info_args)
20 | lora_args = lora_args.config
21 | prompt_args = prompt_args.config
22 |
23 | output_weight_dir = data_args.output_dir + '/best_ckpt'
24 |
25 |
26 | dataHelper = NN_DataHelper(model_args, training_args, data_args)
27 | config_kwargs = {"torch_dtype": torch.float16}
28 | if global_args['config_merge']:
29 | config_kwargs.update(global_args['config_merge'])
30 |
31 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs)
32 |
33 | dataHelper.make_dataset_all()
34 |
35 | is_bf16_supported = torch.cuda.is_bf16_supported()
36 |     # precision; adjust to your hardware
37 | if is_bf16_supported:
38 | precision = 'bf16'
39 | else:
40 | precision = '16'
41 |
42 | if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
43 | precision = "32"
44 |
45 | deepspeed_config = get_deepspeed_config(precision)
46 | strategy = 'ddp' if torch.cuda.device_count() > 1 else 'auto'
47 | if deepspeed_config is not None and len(deepspeed_config):
48 | warmup_ratio = 0.1
49 | with open(train_info_args['train_file'][0]) as f:
50 | total_steps = len(f.readlines()) * train_info_args['max_epochs']
51 | total_steps /= len(train_info_args['devices']) * train_info_args['train_batch_size']
52 | deepspeed_config['scheduler']['params']['warmup_num_steps'] = int(total_steps*warmup_ratio)
53 | deepspeed_config['scheduler']['params']['total_num_steps'] = int(total_steps)
54 | print("total steps: ", int(total_steps))
55 | print("steps per epoch: ", int(total_steps/train_info_args['max_epochs']))
56 | # from IPython import embed
57 | # embed()
58 | # exit()
59 | strategy = DeepSpeedStrategy(config=deepspeed_config, )
60 | checkpoint_callback = ModelCheckpointEx(
61 | # monitor='loss',
62 | dirpath=output_weight_dir,
63 | save_weights_only=True,
64 | save_last=False,
65 | save_top_k=-1,
66 | # every_n_train_steps=2000 // training_args.gradient_accumulation_steps,
67 | every_n_epochs=1,
68 | lora_args=lora_args,
69 | prompt_args=prompt_args,
70 | )
71 |
72 |
73 | trainer = Trainer(
74 | callbacks=[checkpoint_callback, LearningRateMonitor(logging_interval='step')],
75 | max_epochs=training_args.max_epochs,
76 | max_steps=training_args.max_steps,
77 | # max_steps=1,
78 | accelerator="gpu",
79 | devices=data_args.devices,
80 | enable_progress_bar=True,
81 | default_root_dir=data_args.output_dir,
82 | gradient_clip_val=training_args.max_grad_norm,
83 | accumulate_grad_batches=training_args.gradient_accumulation_steps,
84 | num_sanity_val_steps=0,
85 | strategy=strategy,
86 | log_every_n_steps=1,
87 |         # for lora int8 use precision='32'
88 |         precision=precision,  # you can also try "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"
89 | )
90 |
91 |
92 | transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args, prompt_args=prompt_args,
93 | quantization_config=global_args["quantization_config"],
94 | device_map={"": trainer.local_rank} if trainer.world_size > 1 else "auto",
95 | torch_dtype=torch.float16,
96 |                             new_num_tokens=len(tokenizer),  # the vocabulary may have been extended
97 | )
98 |
99 | if transformer_args["quantization_config"] is None:
100 | transformer_args.pop("device_map")
101 |
102 | pl_model = MyTransformer(**transformer_args)
103 |
104 | config.save_pretrained(output_weight_dir)
105 |
106 |     # load sft weights
107 | # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True)
108 |
109 | # pl_model = pl_model.float() if not is_bf16_supported else pl_model.bfloat16()
110 |
111 | def dataset_loader_filter_fn(dataset):
112 | print('*' * 30, 'total', len(dataset))
113 | return dataset
114 |
115 |
116 | train_datasets = dataHelper.load_distributed_random_sampler(
117 | dataHelper.train_files,
118 | with_load_memory=data_args.data_backend == 'record',
119 | collate_fn=dataHelper.collate_fn,
120 | batch_size=training_args.train_batch_size,
121 |         drop_last=training_args.dataloader_drop_last,  # dropping the last incomplete batch is recommended for multi-GPU training
122 | num_processes=trainer.world_size, process_index=trainer.global_rank,
123 | dataset_loader_filter_fn=dataset_loader_filter_fn,
124 | num_workers=training_args.dataloader_num_workers,
125 | pin_memory=training_args.dataloader_pin_memory,
126 | )
127 |
128 | if train_datasets is not None:
129 | trainer.fit(pl_model, train_dataloaders=train_datasets)
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
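For a rough sense of the scheduler setup above: total_num_steps is the number of training lines times max_epochs divided by devices × train_batch_size, and warmup_num_steps is 10% of that. A worked example with hypothetical numbers:

    # Hypothetical numbers, only to illustrate the deepspeed scheduler arithmetic in train.py.
    num_train_lines = 14000      # len(f.readlines()) on the train file
    max_epochs = 3
    num_devices = 8              # len(train_info_args['devices'])
    train_batch_size = 4
    warmup_ratio = 0.1

    total_steps = num_train_lines * max_epochs / (num_devices * train_batch_size)
    print("total steps:", int(total_steps))                   # 1312 -> scheduler total_num_steps
    print("warmup steps:", int(total_steps * warmup_ratio))   # 131  -> scheduler warmup_num_steps
    print("steps per epoch:", int(total_steps / max_epochs))  # 437
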
/src/training/train_hf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : ssbuild
3 | # @Time : 2023/9/25 12:29
4 |
5 |
6 | import logging
7 | import math
8 | import os
9 | import sys
10 | import datasets
11 | import torch
12 | import transformers
13 | from deep_training.trainer.hf.trainer import TrainerHF
14 | from transformers import (
15 | HfArgumentParser,
16 | default_data_collator,
17 | set_seed,
18 | )
19 | from transformers.trainer_utils import get_last_checkpoint
20 | from transformers.utils import check_min_version, send_example_telemetry
21 | from transformers.utils.versions import require_version
22 | from data_utils import NN_DataHelper, train_info_args, get_deepspeed_config, global_args
23 | from aigc_zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments, LoraConfig, PromptArguments
24 | from deep_training.data_helper import ModelArguments, DataArguments,TrainingArgumentsHF
25 |
26 | assert global_args["trainer_backend"] == "hf"
27 |
28 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
29 | check_min_version("4.33.2")
30 |
31 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 | # Setup logging
36 | logging.basicConfig(
37 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
38 | datefmt="%m/%d/%Y %H:%M:%S",
39 | handlers=[logging.StreamHandler(sys.stdout)],
40 | )
41 |
42 | def main():
43 | training_args: TrainingArgumentsHF
44 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsHF, DataArguments, PetlArguments, PromptArguments),
45 | conflict_handler='resolve')
46 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(train_info_args,allow_extra_keys=True,)
47 | lora_args = lora_args.config
48 | prompt_args = prompt_args.config
49 |
50 | if training_args.should_log:
51 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
52 | transformers.utils.logging.set_verbosity_info()
53 |
54 | log_level = training_args.get_process_log_level()
55 | logger.setLevel(log_level)
56 | datasets.utils.logging.set_verbosity(log_level)
57 | transformers.utils.logging.set_verbosity(log_level)
58 | transformers.utils.logging.enable_default_handler()
59 | transformers.utils.logging.enable_explicit_format()
60 |
61 | dataHelper = NN_DataHelper(model_args, training_args, data_args)
62 | config_kwargs = {"torch_dtype": torch.float16}
63 | if global_args['config_merge']:
64 | config_kwargs.update(global_args['config_merge'])
65 |
66 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs)
67 |
68 | with training_args.main_process_first(desc="make_dataset_all"):
69 | dataHelper.make_dataset_all()
70 |
71 | is_bf16_supported = torch.cuda.is_bf16_supported()
72 |     # precision; adjust to your hardware
73 | if is_bf16_supported:
74 | precision = 'bf16'
75 | else:
76 | precision = '16'
77 |
78 | if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit:
79 | precision = "32"
80 |
81 |
82 | if str(precision) == '16':
83 | training_args.fp16 = True
84 | elif str(precision) == 'bf16':
85 | training_args.bf16 = True
86 | else:
87 | training_args.fp16 = False
88 | training_args.bf16 = False
89 |
90 | deepspeed_config = get_deepspeed_config(precision)
91 | if deepspeed_config:
92 | training_args.deepspeed = deepspeed_config
93 |
94 | # Log on each process the small summary:
95 | logger.warning(
96 |         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
97 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
98 | )
99 | logger.info(f"Training/evaluation parameters {training_args}")
100 |
101 | # Detecting last checkpoint.
102 | last_checkpoint = None
103 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
104 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
105 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
106 | raise ValueError(
107 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
108 | "Use --overwrite_output_dir to overcome."
109 | )
110 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
111 | logger.info(
112 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
113 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
114 | )
115 |
116 | # Set seed before initializing model.
117 | set_seed(training_args.seed)
118 |
119 | world_size,local_rank,process_index = training_args.world_size,training_args.local_rank,training_args.process_index
120 |
121 | transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args,
122 | prompt_args=prompt_args,
123 | quantization_config=global_args["quantization_config"],
124 | device_map={"": local_rank} if world_size > 1 else "auto",
125 | torch_dtype=torch.float16,
126 |                             new_num_tokens=len(tokenizer),  # the vocabulary may have been extended
127 | )
128 |
129 | if transformer_args["quantization_config"] is None:
130 | transformer_args.pop("device_map")
131 |
132 | pl_model = MyTransformer(**transformer_args)
133 |
134 | config.save_pretrained(training_args.output_dir)
135 |
136 |     # load sft weights
137 | # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True)
138 |
139 | pl_model = pl_model.float() if not is_bf16_supported else pl_model.bfloat16()
140 |
141 | train_datasets = None
142 | if training_args.do_train:
143 | train_datasets = dataHelper.load_distributed_random_sampler(
144 | dataHelper.train_files,
145 | with_load_memory=data_args.data_backend == 'record',
146 | collate_fn=dataHelper.collate_fn,
147 | batch_size=training_args.train_batch_size,
148 |             drop_last=training_args.dataloader_drop_last,  # dropping the last incomplete batch is recommended for multi-GPU training
149 | num_processes=world_size, process_index=process_index,
150 | num_workers = training_args.dataloader_num_workers,
151 | pin_memory = training_args.dataloader_pin_memory,
152 | )
153 |
154 |
155 |
156 | # Initialize our Trainer
157 | trainer = TrainerHF(
158 | model=pl_model,
159 | args=training_args,
160 | train_dataset=train_datasets,
161 | tokenizer=tokenizer,
162 | # Data collator will default to DataCollatorWithPadding, so we change it.
163 | data_collator=default_data_collator,
164 | )
165 |
166 | # Training
167 | if training_args.do_train:
168 | checkpoint = None
169 | if training_args.resume_from_checkpoint is not None:
170 | checkpoint = training_args.resume_from_checkpoint
171 | elif last_checkpoint is not None:
172 | checkpoint = last_checkpoint
173 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
174 | trainer.save_model() # Saves the tokenizer too for easy upload
175 |
176 | metrics = train_result.metrics
177 | metrics["train_samples"] = len(train_datasets)
178 | trainer.log_metrics("train", metrics)
179 | trainer.save_metrics("train", metrics)
180 | trainer.save_state()
181 |
182 |
183 |
184 |
185 | def _mp_fn(index):
186 | # For xla_spawn (TPUs)
187 | main()
188 |
189 |
190 | if __name__ == "__main__":
191 | main()
192 |
--------------------------------------------------------------------------------