├── .gitignore ├── LICENSE ├── README.md ├── assets └── images │ ├── Snipaste_2023-05-09_00-00-35.png │ ├── Snipaste_2023-05-09_00-04-06.png │ ├── Snipaste_2023-05-09_00-06-44.png │ ├── Snipaste_2023-05-09_00-18-19.png │ ├── Snipaste_2023-05-09_00-24-15.png │ ├── Snipaste_2023-05-09_00-51-19.png │ ├── Snipaste_2023-05-09_01-03-22.png │ ├── Snipaste_2023-05-09_12-07-31.png │ ├── Snipaste_2023-05-09_12-17-42.png │ ├── Snipaste_2023-05-09_12-20-14.png │ ├── Snipaste_2023-05-09_12-22-49.png │ └── total_score.png ├── merge_model.py ├── scripts ├── instruction_tuning │ └── run_train.sh └── pre_training │ └── run_train.sh ├── src ├── __init__.py ├── configuration │ ├── __init__.py │ ├── basic_argsparser.py │ ├── constants.py │ ├── pl_argsparser.py │ └── task │ │ └── config_args.py ├── data_processing │ ├── data_processing.py │ ├── gen_instruction_multi.py │ ├── gen_pretrain_data.py │ ├── gen_tuning_data.py │ ├── generate_instruction.py │ ├── prompt.txt │ ├── test.py │ └── utils.py ├── flash_bloom │ ├── __init__.py │ ├── attention.py │ ├── bert_padding.py │ ├── bloom_flash_attn_monkey_patch.py │ ├── flash_attn_interface.py │ ├── flash_attn_wrapper.py │ └── llama_flash_attn_monkey_patch.py ├── models │ ├── __init__.py │ ├── basic_pl_trainer.py │ ├── basic_trainer.py │ ├── lightning_base.py │ └── task │ │ ├── __init__.py │ │ ├── gpt_code.py │ │ └── llama_finetune.py ├── modules │ ├── __init__.py │ ├── pl_callbacks.py │ └── task │ │ ├── conversation.py │ │ ├── data_loader.py │ │ └── dataset.py ├── requirements.txt ├── serve │ └── hanfei_serve.py ├── utils │ ├── 230525模型对比表.xlsx │ ├── __init__.py │ ├── compare_score_count.png │ ├── dataset_utils.py │ ├── file_utils.py │ ├── gen_utils.py │ ├── huggingface_helper.py │ ├── metric_utils.py │ ├── nlg_eval_utils.py │ ├── plot.py │ ├── print_utils.py │ ├── string_utils.py │ ├── task │ │ ├── __init__.py │ │ ├── event_analyzer_utils.py │ │ ├── event_utils.py │ │ ├── model_utils.py │ │ └── stat_utils.py │ ├── total_score.png │ ├── total_score_ratio.png │ └── wrapper.py └── web │ └── hanfei_app.py ├── tasks ├── download_hf_models.sh └── task │ ├── CoT_test.py │ ├── convert_metadata.py │ ├── corpus_stat.py │ ├── test.py │ └── train.py ├── toy_hanfei.py └── zero_to_fp32.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # custom 132 | *.bin 133 | .DS_Store 134 | logging_*.* 135 | data/ 136 | model/ 137 | .idea/ 138 | output/ 139 | wandb/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# HanFei

## Introduction

HanFei-1.0 (韩非) is the first Chinese legal large language model trained with **full-parameter** updates. It has 7B parameters, and its main capabilities include legal question answering, multi-turn dialogue, article drafting, and retrieval (coming soon).

[Beta demo](http://siat.yang42.com:10185/)

Example 1

![example 1](./assets/images/Snipaste_2023-05-09_12-17-42.png)

Example 2

![example 2](./assets/images/Snipaste_2023-05-09_00-00-35.png)

## Data

`Note: only the HanFei-1.0 fine-tuning data is open-sourced; the pre-training data is not released.`

### Data processing

src/data_processing/gen_pretrain_data.py: builds the pre-training data

src/data_processing/gen_instruction_multi.py: generates instruction data and merges the fine-tuning data

### Pre-training data

Composition: court cases, laws and regulations, complaints, legal news, and similar legal text.

Size: about 60 GB of data, packed into sequences of 2,048 tokens each.

### Fine-tuning data

#### HanFei 1.0

The first version was filtered with rules.

| Data type | File name | Size |
| :--- | :--- | :--- |
| Chinese general instructions | zh_general_instruction.json | 53k |
| Chinese legal instructions | zh_law_instruction.json | 41k |
| Chinese general conversations | zh_general_conversation.json | 55k |
| Chinese legal conversations | zh_law_conversation.json | 56k |
| Chinese legal QA pairs | zh_law_qa.json | 50k |

#### Data and model downloads

Baidu Netdisk:

Link: https://pan.baidu.com/s/1PkRXUo9sNRQmoXHcW7Aeeg?pwd=d6t5

Extraction code: d6t5

#### HanFei 2.0 (in development)

The second version is filtered manually.

### Model evaluation data

| Data type | Data path | Description | Size |
| :--- | :--- | :--- | :--- |
| Legal questions | data/evaluation_dataset | Covers 9 areas such as labor and marriage law | 150 items |

## Evaluation metrics

We use human evaluation: for each legal consultation question, three language models (HanFei, BLOOMZ, and ChatGPT) each generate an answer, and professional lawyers are hired to score every model's answer.

+ Metric 1: 0-10 points, where 0 is worst and 10 is best.

A total of 150 questions were evaluated. We sum each model's scores and use the total score to measure answer quality; the results are shown below:

![total_score](./assets/images/total_score.png)

## Training

### Environment requirements

8 × A100/A800

### Training commands

```sh
# Step 1: legal-domain pre-training
sh scripts/pre_training/run_train.sh

# Step 2: instruction tuning
sh scripts/instruction_tuning/run_train.sh

```

## Deployment

### Environment requirements

40 GB of GPU memory is enough, e.g. a single A100/A800 or two TITAN RTX cards.
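As a rough illustration of this memory budget, here is a minimal inference sketch that loads a merged HanFei checkpoint in half precision with Hugging Face `transformers`. It is a sketch under assumptions, not part of this repository: the checkpoint path and the prompt are placeholders, and `device_map="auto"` assumes `accelerate` is installed.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: point this at the merged checkpoint produced by merge_model.py.
model_path = "/path/to/hanfei-7b"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,      # fp16 keeps the 7B weights around 14-15 GB
    ignore_mismatched_sizes=True,   # needed if lm_head was saved empty (see merge_model.py)
    device_map="auto",              # requires `accelerate`; places layers on the available GPU(s)
)
model.eval()

prompt = "劳动合同到期后公司不续签,是否需要支付经济补偿?"  # placeholder legal question
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, top_p=0.9)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```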
### Deployment commands

```sh
# Gradio web UI
python src/web/hanfei_app.py

# RESTful API
python src/serve/hanfei_serve.py
```

## Project contributors

This project was developed by He Wanwei, Wen Jiabao, Zhang Lei, Cheng Hao, Qin Bowen, Li Yunshui, and Li Zhijian of the SIAT-DeLi Legal Artificial Intelligence Joint Lab at the Shenzhen Institute of Advanced Technology (SIAT), Chinese Academy of Sciences, together with Jiang Feng and Chen Junying of the Shenzhen Research Institute of Big Data and The Chinese University of Hong Kong, Shenzhen. The supervisors are Assistant Professor **Wang Benyou** (Shenzhen Research Institute of Big Data; CUHK-Shenzhen) and Associate Professor **Yang Min** (SIAT, Chinese Academy of Sciences).

## Disclaimer

The resources in this project are for academic research only; commercial use is strictly prohibited. When using parts that involve third-party code, strictly follow the corresponding open-source licenses. Model outputs are affected by model computation, randomness, quantization precision loss, and other factors, so this project cannot guarantee their accuracy. This project assumes no legal liability, nor any responsibility for losses that may arise from using the released resources or the generated results.

## Acknowledgements

This project drew on the following open-source projects; we thank these projects and their developers.

Bloom: https://huggingface.co/bigscience/bloom

Facebook LLaMA: https://github.com/facebookresearch/llama

Stanford Alpaca: https://github.com/tatsu-lab/stanford_alpaca

Self-instruct: https://github.com/yizhongw/self-instruct

## Citation

If you use the contents of this project, or find it helpful for your research, please cite it:

```
@misc{HanFei,
  author={Wanwei He and Jiabao Wen and Lei Zhang and Hao Cheng and Bowen Qin and Yunshui Li and Feng Jiang and Junying Chen and Benyou Wang and Min Yang},
  title={HanFei-1.0},
  year={2023},
  publisher={GitHub},
  journal={GitHub repository},
  howpublished={\url{https://github.com/siat-nlp/HanFei}},
}
```

--------------------------------------------------------------------------------
/assets/images/Snipaste_2023-05-09_00-00-35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_00-00-35.png

--------------------------------------------------------------------------------
/assets/images/Snipaste_2023-05-09_00-04-06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_00-04-06.png

--------------------------------------------------------------------------------
/assets/images/Snipaste_2023-05-09_00-06-44.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_00-06-44.png

--------------------------------------------------------------------------------
/assets/images/Snipaste_2023-05-09_00-18-19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_00-18-19.png

--------------------------------------------------------------------------------
/assets/images/Snipaste_2023-05-09_00-24-15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_00-24-15.png

--------------------------------------------------------------------------------
/assets/images/Snipaste_2023-05-09_00-51-19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_00-51-19.png

--------------------------------------------------------------------------------
/assets/images/Snipaste_2023-05-09_01-03-22.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_01-03-22.png -------------------------------------------------------------------------------- /assets/images/Snipaste_2023-05-09_12-07-31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_12-07-31.png -------------------------------------------------------------------------------- /assets/images/Snipaste_2023-05-09_12-17-42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_12-17-42.png -------------------------------------------------------------------------------- /assets/images/Snipaste_2023-05-09_12-20-14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_12-20-14.png -------------------------------------------------------------------------------- /assets/images/Snipaste_2023-05-09_12-22-49.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/Snipaste_2023-05-09_12-22-49.png -------------------------------------------------------------------------------- /assets/images/total_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/assets/images/total_score.png -------------------------------------------------------------------------------- /merge_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | # from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict 5 | # from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict 6 | from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict 7 | 8 | 9 | if __name__ == '__main__': 10 | """ 11 | 若lm_head参数为空,则在from_pretrained加载模型时需指定ignore_mismatched_sizes=True 12 | """ 13 | experiment_name = "bloomz-7b1-mt-sft-gpu8-1e5" 14 | epoch, global_step = 1, 3853 15 | filename = experiment_name + "-epoch={epoch:02d}-step={step}".format(epoch=epoch, step=global_step) 16 | 17 | root_dir = f"/HanFei/output/task/{experiment_name}" 18 | output_dir = f"{root_dir}/models/global_step{global_step}" 19 | if not os.path.exists(output_dir): 20 | os.makedirs(output_dir) 21 | 22 | huggingface_dir = os.path.join(root_dir, "best_tfmr") 23 | files = os.listdir(huggingface_dir) 24 | for file_name in files: 25 | if file_name == "pytorch_model.bin": 26 | continue 27 | src_path = os.path.join(huggingface_dir, file_name) 28 | dest_path = os.path.join(output_dir, file_name) 29 | if not os.path.exists(dest_path): 30 | shutil.copy2(src_path, dest_path) 31 | 32 | # lightning deepspeed has saved a directory instead of a file 33 | save_path = f"{root_dir}/checkpoints/{filename}.ckpt" 34 | output_path = os.path.join(output_dir, "pytorch_model.bin") 35 | 
convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) 36 | -------------------------------------------------------------------------------- /scripts/instruction_tuning/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Main run. 5 | python tasks/task/train.py \ 6 | --experiment_name=bloomz-7b1-mt-sft-gpu8-1e5 \ 7 | --training_stage=instruction_tuning \ 8 | --model_name=gpt \ 9 | --model_name_or_path=/model_name_or_path \ 10 | --tokenizer_name=/model_name_or_path \ 11 | --data_dir=/data_path \ 12 | --train_batch_size=1 \ 13 | --eval_batch_size=1 \ 14 | --accum_batches_args=16 \ 15 | --max_source_length=4096 \ 16 | --max_target_length=4096 \ 17 | --num_sanity_val_steps=0 \ 18 | --val_check_interval=0.1 \ 19 | --log_every_n_steps=10 \ 20 | --save_every_n_epochs=1 \ 21 | --save_top_k=-1 \ 22 | --max_epochs=3 \ 23 | --max_steps=-1 \ 24 | --lr_scheduler=cosine \ 25 | --learning_rate=1e-5 \ 26 | --warmup_steps=200 \ 27 | --weight_decay=0. \ 28 | --logger_name=WandbLogger \ 29 | --num_workers=8 30 | -------------------------------------------------------------------------------- /scripts/pre_training/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Main run. 5 | python tasks/task/train.py \ 6 | --experiment_name=bloomz-7b1-mt-gpu8-1e5 \ 7 | --training_stage=pre_training \ 8 | --model_name=gpt \ 9 | --model_name_or_path=/model_name_or_path \ 10 | --tokenizer_name=/model_name_or_path \ 11 | --data_dir=/data_path \ 12 | --train_batch_size=1 \ 13 | --eval_batch_size=1 \ 14 | --accum_batches_args=16 \ 15 | --max_source_length=2048 \ 16 | --max_target_length=2048 \ 17 | --num_sanity_val_steps=0 \ 18 | --val_check_interval=0.1 \ 19 | --log_every_n_steps=10 \ 20 | --save_every_n_epochs=1 \ 21 | --save_top_k=-1 \ 22 | --max_epochs=2 \ 23 | --max_steps=-1 \ 24 | --lr_scheduler=linear \ 25 | --learning_rate=1e-5 \ 26 | --warmup_steps=1000 \ 27 | --weight_decay=0. 
\ 28 | --logger_name=WandbLogger \ 29 | --num_workers=8 30 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/__init__.py -------------------------------------------------------------------------------- /src/configuration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/configuration/__init__.py -------------------------------------------------------------------------------- /src/configuration/basic_argsparser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | 5 | from src.configuration.constants import BASE_DIR 6 | 7 | OUTPUT_DIR = f'{BASE_DIR}/output' 8 | 9 | 10 | def add_basic_args(parser: argparse.ArgumentParser = None): 11 | if parser is None: 12 | parser = argparse.ArgumentParser() 13 | # directories 14 | parser.add_argument('--output_dir', type=str, 15 | default=os.getenv('PT_OUTPUT_DIR', f'{BASE_DIR}/output/'), 16 | help='directory to save logs and checkpoints to') 17 | parser.add_argument('--data_dir', type=str, 18 | default=f'{BASE_DIR}/datasets/', 19 | help='directory with train, dev, test files') 20 | 21 | parser.add_argument('--desc', type=str, help="Description for the experiment.") 22 | parser.add_argument('--seed', type=int, default=42, help="Fixing random seeds helps reproduce the result.") 23 | parser.add_argument('--num_epochs', type=int, default=10, help="The max training epochs.") 24 | parser.add_argument('--experiment_name', type=str, 25 | default='plotmachine-default', 26 | help='name of this experiment will be included in output') 27 | 28 | # single parameter 29 | parser.add_argument('--show_progress', action='store_true', default=True, 30 | help='It will show the progress.') 31 | parser.add_argument('--no_gpu', action='store_true', default=False, 32 | help='Runing with gpu is banned.') 33 | parser.add_argument('--no_multi_gpus', action='store_true', default=False, 34 | help='Runing with multiple gpu is banned.') 35 | return parser 36 | -------------------------------------------------------------------------------- /src/configuration/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | FILE_PATH = Path(__file__).absolute() 11 | BASE_DIR = FILE_PATH.parent.parent.parent 12 | 13 | from transformers import ( 14 | AutoModel, 15 | AutoModelForPreTraining, 16 | AutoModelForQuestionAnswering, 17 | AutoModelForSeq2SeqLM, 18 | AutoModelForSequenceClassification, 19 | AutoModelForTokenClassification, 20 | AutoModelWithLMHead, 21 | ) 22 | 23 | from transformers.optimization import ( 24 | get_cosine_schedule_with_warmup, 25 | get_cosine_with_hard_restarts_schedule_with_warmup, 26 | get_linear_schedule_with_warmup, 27 | get_polynomial_decay_schedule_with_warmup, 28 | ) 29 | 30 | from src.modules.task.data_loader import ( 31 | PretrainingDataset, 32 | SupervisedDataset, 33 | ) 34 | 35 | 36 | MODEL_CLASSES = { 37 | "base": AutoModel, 38 | "sequence-classification": AutoModelForSequenceClassification, 39 | "question-answering": AutoModelForQuestionAnswering, 40 | "pretraining": 
AutoModelForPreTraining, 41 | "token-classification": AutoModelForTokenClassification, 42 | "language-modeling": AutoModelWithLMHead, 43 | "summarization": AutoModelForSeq2SeqLM, 44 | "translation": AutoModelForSeq2SeqLM, 45 | } 46 | 47 | DATASET_CLASSES = { 48 | "pre_training": PretrainingDataset, 49 | "instruction_tuning": SupervisedDataset, 50 | } 51 | 52 | # update this and the import above to support new schedulers from transformers.optimization 53 | GET_SCHEDULER_FUNCS = { 54 | "linear": get_linear_schedule_with_warmup, 55 | "cosine": get_cosine_schedule_with_warmup, 56 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 57 | "polynomial": get_polynomial_decay_schedule_with_warmup, 58 | } 59 | -------------------------------------------------------------------------------- /src/configuration/pl_argsparser.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | - pytorch_lightning documentation about trainer parameters: 5 | https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html? 6 | highlight=lighting_logs#pytorch_lightning.trainer.Trainer.params.flush_logs_every_n_steps 7 | - pl.Trainer.add_argparse_args: 8 | https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.utilities.argparse.html 9 | ?highlight=add_argparse_args#pytorch_lightning.utilities.argparse.add_argparse_args 10 | @Notes: 11 | parser.set_defaults(x="...") don't require parser.add_argument("--x", ...) 12 | """ 13 | 14 | import argparse 15 | import torch 16 | 17 | 18 | def set_device_args_for_pl_trainer(parser: argparse.ArgumentParser = None): 19 | if parser is None: 20 | parser = argparse.ArgumentParser() 21 | if torch.cuda.is_available(): 22 | # use gpu 23 | # If your machine has GPUs, it will use the GPU Accelerator for training 24 | parser.set_defaults(accelerator="gpu") 25 | # Number of GPUs to train on (int) 26 | parser.set_defaults(gpus=torch.cuda.device_count()) 27 | else: 28 | # use cpu 29 | parser.set_defaults(accelerator="cpu") 30 | parser.set_defaults(devices=1) 31 | parser.set_defaults(gpus=0) 32 | return parser 33 | 34 | 35 | def set_basic_args_for_pl_trainer(parser: argparse.ArgumentParser = None, output_dir=None): 36 | if parser is None: 37 | parser = argparse.ArgumentParser() 38 | if output_dir is None: 39 | raise ValueError("Must input the output_dir") 40 | # "Number of updates steps to accumulate before performing a backward/update pass." 41 | parser.add_argument('--accum_batches_args', type=str, help='set accumulate_grad_batches') 42 | # Stop training once this number of epochs is reached 43 | parser.set_defaults(max_epochs=10) 44 | # Force training for at least these many epochs 45 | # parser.set_defaults(min_epochs=2) 46 | # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that 47 | parser.set_defaults(accumulate_grad_batches={5: 2, 10: 5}) 48 | # Automatically tries to find the largest batch size that fits into memory, before any training. 49 | # Trainer(auto_scale_batch_size="binsearch") run batch size scaling, result overrides hparams.batch_size 50 | parser.set_defaults(auto_scale_batch_size=None) 51 | # default used by the Trainer (no learning rate finder) 52 | parser.set_defaults(auto_lr_find=False) 53 | # How often within one training epoch to check the validation set. Can specify as float or int. 
54 | parser.set_defaults(val_check_interval=0.5) 55 | # Default path for logs and weights when no logger or pytorch_lightning.callbacks.ModelCheckpoint callback passed. 56 | parser.set_defaults(default_root_dir=output_dir) 57 | # default used by Trainer, saves the most recent model to a single checkpoint after each epoch 58 | parser.set_defaults(enable_checkpointing=True) 59 | # Runs n if set to n (int) else 1 if set to True batch(es) of train, val and test to 60 | # find any bugs (ie: a sort of unit test). 61 | parser.set_defaults(fast_dev_run=False) 62 | # How often to add logging rows (does not write to disk) 63 | parser.set_defaults(log_every_n_steps=50) 64 | # Sanity check runs n batches of val before starting the training routine. 65 | parser.set_defaults(num_sanity_val_steps=1) 66 | # Whether to enable or disable the progress bar. Defaults to True. 67 | parser.set_defaults(enable_progress_bar=True) 68 | # Enable synchronization between batchnorm layers across all GPUs. 69 | parser.set_defaults(sync_batchnorm=True) 70 | # Directory of where to save weights if specified. 71 | parser.set_defaults(weights_save_path=output_dir) 72 | # Gradient clipping value 73 | parser.set_defaults(gradient_clip_val=None) 74 | # Uses this much data of the training set. If nonzero, will turn off validation. 75 | # If the training dataloaders have shuffle=True, Lightning will automatically disable it. 76 | parser.set_defaults(overfit_batches=0.0) 77 | return parser 78 | 79 | 80 | def set_speedup_args_for_pl_trainer(parser: argparse.ArgumentParser = None, amp_backend="apex", precision=16): 81 | if parser is None: 82 | parser = argparse.ArgumentParser() 83 | if amp_backend == "native": 84 | parser.set_defaults(precision=precision) 85 | elif amp_backend == "apex": 86 | # The optimization level to use (O1, O2, etc…) for 16-bit GPU precision (using NVIDIA apex under the hood). 87 | parser.set_defaults(amp_level='O2') 88 | # Lightning supports either double precision (64), full precision (32), or half precision (16) training. 89 | parser.set_defaults(precision=precision) 90 | else: 91 | raise NotImplementedError(f"amp_backend: {amp_backend}") 92 | # Use PyTorch AMP (‘native’) (available PyTorch 1.6+), or NVIDIA apex (‘apex’). 93 | parser.set_defaults(amp_backend=amp_backend) 94 | return parser 95 | 96 | 97 | def process_parsed_args_for_pl_trainer(args: argparse.Namespace): 98 | if args.accum_batches_args is not None: 99 | batches = eval(args.accum_batches_args) 100 | print(f"reset accumulate_grad_batches to {batches}") 101 | args.accumulate_grad_batches = batches 102 | # precision 103 | if args.accelerator == "cpu" and args.precision == 16: 104 | args.precision = "bf16" 105 | else: # gpu 106 | n_gpus = torch.cuda.device_count() 107 | if n_gpus > 1: # multiple gpus 108 | args.accelerator = "gpu" 109 | # args.strategy = "ddp" 110 | args.strategy = "deepspeed_stage_3" 111 | # args.strategy = "fsdp" 112 | # args.precision = 32 113 | -------------------------------------------------------------------------------- /src/configuration/task/config_args.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | - pytorch_lightning documentation about trainer parameters: 5 | https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html? 
6 | highlight=lighting_logs#pytorch_lightning.trainer.Trainer.params.flush_logs_every_n_steps 7 | - pl.Trainer.add_argparse_args: 8 | https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.utilities.argparse.html 9 | ?highlight=add_argparse_args#pytorch_lightning.utilities.argparse.add_argparse_args 10 | @Notes: 11 | os.environ["TOKENIZERS_PARALLELISM"] = "false" because 12 | https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning 13 | """ 14 | 15 | import os 16 | import argparse 17 | 18 | import pytorch_lightning as pl 19 | from src.utils.string_utils import str2bool 20 | 21 | from src.configuration.constants import BASE_DIR 22 | from src.configuration.pl_argsparser import ( 23 | set_basic_args_for_pl_trainer, 24 | set_speedup_args_for_pl_trainer, 25 | set_device_args_for_pl_trainer, 26 | process_parsed_args_for_pl_trainer, 27 | ) 28 | 29 | EXPERIMENT_GROUP = "task" 30 | MODEL_NAME = "know_pre" 31 | OUTPUT_DIR = f'{BASE_DIR}/output/{EXPERIMENT_GROUP}' 32 | DATASETS_DIR = f'{BASE_DIR}/resources' 33 | RESOURCES_DIR = f'{BASE_DIR}/resources' 34 | DATA_NAME = "roc-stories" 35 | MODEL_NAME_OR_PATH = f'{RESOURCES_DIR}/external_models/t5-base' 36 | 37 | 38 | def add_customized_args(parser: argparse.ArgumentParser = None): 39 | if parser is None: 40 | parser = argparse.ArgumentParser() 41 | 42 | parser.add_argument( 43 | "--max_source_length", 44 | default=512, 45 | type=int, 46 | help="The maximum total input sequence length after tokenization. Sequences longer " 47 | "than this will be truncated, sequences shorter will be padded.", 48 | ) 49 | parser.add_argument( 50 | "--max_target_length", 51 | default=512, 52 | type=int, 53 | help="The maximum total input sequence length after tokenization. Sequences longer " 54 | "than this will be truncated, sequences shorter will be padded.", 55 | ) 56 | parser.add_argument("--freeze_encoder", action="store_true") 57 | parser.add_argument("--freeze_embeds", action="store_true") 58 | parser.add_argument("--max_tokens_per_batch", type=int, default=None) 59 | parser.add_argument("--logger_name", type=str, choices=["CSVLogger", "WandbLogger"], default="CSVLogger") 60 | parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) 61 | parser.add_argument("--eval_beams", type=int, default=4, required=False) 62 | parser.add_argument("--num_beam_groups", type=int, default=0, required=False) 63 | parser.add_argument("--repetition_penalty", type=int, default=0, required=False) 64 | parser.add_argument("--num_return_sequences", type=int, default=20, required=False) 65 | parser.add_argument("--diversity_penalty", type=int, default=0, required=False) 66 | parser.add_argument("--r_drop_alpha", type=int, default=5, required=False) 67 | parser.add_argument( 68 | "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None] 69 | ) 70 | parser.add_argument( 71 | "--early_stopping_patience", 72 | type=int, 73 | default=-1, 74 | required=False, 75 | help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. " 76 | "So val_check_interval will effect it.", 77 | ) 78 | parser.add_argument("--fast_generate", action="store_true", default=False, 79 | help="for _generative_step") 80 | parser.add_argument("--remain_sp_tokens", action="store_true", default=False, 81 | help="remain special tokens in target and pred text (e.g. 
[EVENT_s])") 82 | parser.add_argument("--training_stage", type=str, choices=["pre_training", "instruction_tuning"], 83 | default="pre_training", help="Training stage") 84 | 85 | return parser 86 | 87 | 88 | def add_args_for_pytorch_lightning(parser: argparse.ArgumentParser = None): 89 | if parser is None: 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument( 92 | "--tokenizer_name", 93 | default=None, 94 | type=str, 95 | help="Pretrained tokenizer name or path if not the same as model_name", 96 | ) 97 | parser.add_argument( 98 | "--encoder_layerdrop", 99 | type=float, 100 | help="Encoder layer dropout probability (Optional). Goes into model.config", 101 | ) 102 | parser.add_argument( 103 | "--decoder_layerdrop", 104 | type=float, 105 | help="Decoder layer dropout probability (Optional). Goes into model.config", 106 | ) 107 | parser.add_argument( 108 | "--dropout", 109 | type=float, 110 | help="Dropout probability (Optional). Goes into model.config", 111 | ) 112 | parser.add_argument( 113 | "--attention_dropout", 114 | type=float, 115 | help="Attention dropout probability (Optional). Goes into model.config", 116 | ) 117 | parser.add_argument( 118 | "--lr_scheduler", 119 | default="linear", 120 | type=str, 121 | help="Learning rate scheduler", 122 | ) 123 | parser.add_argument("--weight_decay", default=0., type=float, help="Weight decay if we apply some.") 124 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 125 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 126 | parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") 127 | parser.add_argument("--optimizer_class", type=str, default="AdamW", help="optimizers: Adafactor|AdamW") 128 | 129 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 130 | 131 | parser.add_argument('--output_dir', type=str, 132 | default=os.getenv('PT_OUTPUT_DIR', OUTPUT_DIR), 133 | help='directory to save logs and checkpoints to') 134 | parser.add_argument('--data_dir', type=str, 135 | default=f'{DATASETS_DIR}', 136 | help='directory with train, dev, test files') 137 | parser.add_argument('--index_ratio', type=int, 138 | default=2, 139 | help='multi-task ratio') 140 | parser.add_argument('--qa_exist', type=int, choices=[0, 1], 141 | default=1, 142 | help='train on qa task') 143 | 144 | parser.add_argument('--id_len', type=int, 145 | default=8, 146 | help='length of id') 147 | parser.add_argument('--resources_dir', type=str, 148 | default=f'{RESOURCES_DIR}', 149 | help='directory with of resources, including pretrained off-line models.') 150 | parser.add_argument("--experiment_name", default=f"{MODEL_NAME}-{DATA_NAME}", 151 | type=str, help="the name of the experiment.") 152 | parser.add_argument("--model_name", default=f"{MODEL_NAME}", 153 | type=str, help="the name of the model used.") 154 | parser.add_argument("--model_name_or_path", 155 | default=MODEL_NAME_OR_PATH, 156 | type=str, 157 | help="Path to pretrained model or model identifier from huggingface.co/models.") 158 | parser.add_argument("--learning_rate", type=float, default=5e-5, help="learning rate.") 159 | parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") 160 | parser.add_argument("--n_val", type=int, default=1000, required=False, help="# examples. 
-1 means use all.") 161 | parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") 162 | parser.add_argument("--overwrite_output_dir", action="store_true", default=True, help="overwrite_output_dir") 163 | parser.add_argument("--train_batch_size", default=10, type=int, help="train_batch_size.") 164 | parser.add_argument("--eval_batch_size", default=10, type=int, help="eval_batch_size.") 165 | 166 | # ############################### checkpoint settings #################################### 167 | parser.add_argument("--save_top_k", default=-1, type=int, 168 | help="The best k models according to the quantity monitored will be saved.") 169 | parser.add_argument("--save_every_n_epochs", default=None, type=int, 170 | help="Save on every n epochs") 171 | parser.add_argument("--save_every_n_steps", default=None, type=int, 172 | help="Save on every n steps") 173 | parser.add_argument("--every_n_val_epochs", default=1, type=int, 174 | help="every_n_val_epochs.") 175 | parser.add_argument("--ckpt_verbose", type=str2bool, default=False, 176 | help="verbosity mode. True / False") 177 | 178 | return parser 179 | 180 | 181 | def parse_args_for_config(parser: argparse.ArgumentParser = None): 182 | if parser is None: 183 | parser = argparse.ArgumentParser() 184 | 185 | parser = pl.Trainer.add_argparse_args(parser) # Extends existing argparse by default attributes for cls. 186 | # set defaults for args of pl.trainer 187 | parser = set_basic_args_for_pl_trainer(parser, output_dir=OUTPUT_DIR) 188 | parser = set_speedup_args_for_pl_trainer(parser, amp_backend="native", precision=16) 189 | parser = set_device_args_for_pl_trainer(parser) 190 | 191 | # customized 192 | parser = add_args_for_pytorch_lightning(parser) 193 | parser = add_customized_args(parser) 194 | args = parser.parse_args() 195 | 196 | # process parsed args 197 | process_parsed_args_for_pl_trainer(args) 198 | return args 199 | 200 | 201 | if __name__ == '__main__': 202 | import os 203 | 204 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 205 | args = parse_args_for_config() 206 | -------------------------------------------------------------------------------- /src/data_processing/gen_instruction_multi.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import time 3 | import json 4 | import os 5 | import random 6 | import re 7 | from multiprocessing import Pool 8 | import multiprocessing 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | import rich.progress 11 | 12 | from tqdm import tqdm 13 | import utils 14 | 15 | import openpyxl 16 | 17 | random.seed(42) 18 | 19 | def encode_prompt(prompt_instructions, prompt): 20 | prompt = prompt + "\n" 21 | 22 | idx = 1 23 | for _, task_dict in enumerate(prompt_instructions): 24 | (instruction, input, output) = task_dict["instruction"], task_dict["input"], task_dict["output"] 25 | instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":") 26 | input = "" if input == "" else input 27 | prompt += f"###\n" 28 | prompt += f"[{idx}]. 指令:{instruction}\n" 29 | prompt += f"[{idx}]. 输入:\n{input}\n" 30 | prompt += f"[{idx}]. 输出:\n{output}\n" 31 | idx += 1 32 | prompt += f"###\n" 33 | prompt += f"[{idx}]. 
指令:" 34 | return prompt 35 | 36 | def merge_data(): 37 | # 列出所有的文件 38 | files = os.listdir('./../data/zh_law_instruction/') 39 | print(files) 40 | datas = [] 41 | for file in tqdm(files): 42 | if file.endswith('.json'): 43 | with open('./../data/zh_law_instruction/' + file, 'r', encoding='utf-8') as f: 44 | lines = f.readlines() 45 | for line in tqdm(lines): 46 | data = json.loads(line) 47 | datas.append(data) 48 | with open('./../data/zh_law_instruction/raw_zh_law_instruction.json', 'w', encoding='utf-8') as f: 49 | for data in tqdm(datas): 50 | f.write(json.dumps(data, ensure_ascii=False) + '\n') 51 | 52 | def post_process_gpt3_response(num_prompt_instructions, response): 53 | if response is None: 54 | return [] 55 | raw_instructions = f"[{num_prompt_instructions+1}]. 指令:" + response.message.content 56 | raw_instructions = re.split("###", raw_instructions) 57 | 58 | instructions = [] 59 | for idx, inst in enumerate(raw_instructions): 60 | # if the decoding stops due to length, the last example is likely truncated so we discard it 61 | if idx == len(raw_instructions) - 1 and response["finish_reason"] == "length": 62 | continue 63 | idx = idx + num_prompt_instructions + 1 64 | splitted_data = re.split(f"\[{idx}\]\.\s+(指令|输入|输出):", inst) 65 | 66 | # print(splitted_data) 67 | if len(splitted_data) != 7: 68 | continue 69 | else: 70 | inst = splitted_data[2].strip() 71 | input = splitted_data[4].strip() 72 | input = "" if input.lower() == "" else input 73 | output = splitted_data[6].strip() 74 | instructions.append({"instruction": inst, "input": input, "output": output}) 75 | return instructions 76 | 77 | 78 | def read_xlsx(path): 79 | book = openpyxl.load_workbook(path) 80 | sheet = book.active 81 | data = [] 82 | for row in range(2, 204): 83 | instruction = sheet.cell(row=row, column=2).value 84 | input = sheet.cell(row=row, column=3).value 85 | output = sheet.cell(row=row, column=4).value 86 | if instruction != "": 87 | data.append({"instruction": instruction, "input": input, "output": output}) 88 | return data 89 | 90 | def genInstruction(cpu_id, num_instructions_to_generate, seed_instruction_data, num_prompt_instructions, output_dir): 91 | 92 | prompt = open("./prompt.txt", encoding='utf-8').read() 93 | random.seed(1024+cpu_id) 94 | for i in tqdm(range(num_instructions_to_generate)): 95 | # 随机选取num_prompt_instructions个seed instruction 96 | prompt_instructions = random.sample(seed_instruction_data, num_prompt_instructions) 97 | inputs = encode_prompt(prompt_instructions, prompt) 98 | batch_inputs = [inputs] 99 | decoding_args = utils.OpenAIDecodingArguments() 100 | request_start = time.time() 101 | results = utils.openai_completion( # 一个输入 102 | prompts=batch_inputs, 103 | batch_size=1, 104 | decoding_args=decoding_args, 105 | # logit_bias={"50256": -100}, # prevent the <|endoftext|> token from being generated 106 | ) 107 | 108 | request_duration = time.time() - request_start 109 | 110 | process_start = time.time() 111 | instruction_data = [] 112 | for result in results: 113 | new_instructions = post_process_gpt3_response(num_prompt_instructions, result) 114 | instruction_data += new_instructions 115 | # print(instruction_data) 116 | 117 | process_duration = time.time() - process_start 118 | print(f"Request {i} took {request_duration:.2f}s, processing took {process_duration:.2f}s") 119 | # print(f"Generated {total} instructions, kept {keep} instructions") 120 | # 写到json文件,一行一个数据 121 | with open(os.path.join(output_dir, f"{cpu_id+20}.json"), mode='a', encoding='utf-8') as f: 122 | if 
len(instruction_data) > 0: 123 | f.write('\n'.join(json.dumps(instruction, ensure_ascii=False) for instruction in instruction_data) +'\n') 124 | 125 | # utils.jdump(instruction_data, os.path.join(output_dir, "regen.json")) 126 | 127 | def generate_instruction_following_data( 128 | output_dir="../data/", 129 | seed_tasks_path="../data/seed_task.xlsx", 130 | num_instructions_to_generate=4000, 131 | num_prompt_instructions=3, 132 | request_batch_size=1, 133 | num_cpus=12, 134 | ): 135 | 136 | seed_instruction_data = read_xlsx(seed_tasks_path) 137 | 138 | print(f"Loaded {len(seed_instruction_data)} seed tasks") 139 | 140 | for i in range(num_cpus): 141 | p = multiprocessing.Process(target=genInstruction, args=(i, num_instructions_to_generate, seed_instruction_data, num_prompt_instructions, output_dir)) 142 | p.start() 143 | 144 | if __name__ == "__main__": 145 | generate_instruction_following_data() 146 | 147 | merge_data() -------------------------------------------------------------------------------- /src/data_processing/generate_instruction.py: -------------------------------------------------------------------------------- 1 | """ 2 | batch_selfinstruct_generate.py 3 | 4 | run: 5 | python -m generate_instruction generate_instruction_following_data \ 6 | --output_dir ./ \ 7 | --num_instructions_to_generate 10 \ 8 | --model_name="text-davinci-003" \ 9 | """ 10 | import time 11 | import json 12 | import os 13 | import random 14 | import re 15 | import string 16 | from functools import partial 17 | from multiprocessing import Pool 18 | 19 | import numpy as np 20 | import tqdm 21 | from rouge_score import rouge_scorer 22 | import utils 23 | 24 | import openpyxl 25 | 26 | def encode_prompt(prompt_instructions): 27 | """Encode multiple prompt instructions into a single string.""" 28 | prompt = open("./prompt.txt", encoding='utf-8').read() + "\n" 29 | idx = 1 30 | for _, task_dict in enumerate(prompt_instructions): 31 | (instruction, input, output) = task_dict["instruction"], task_dict["input"], task_dict["output"] 32 | instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":") 33 | input = "" if input == "" else input 34 | prompt += f"###\n" 35 | prompt += f"[{idx}]. 指令:{instruction}\n" 36 | prompt += f"[{idx}]. 输入:\n{input}\n" 37 | prompt += f"[{idx}]. 输出:\n{output}\n" 38 | idx += 1 39 | prompt += f"###\n" 40 | prompt += f"[{idx}]. 指令:" 41 | return prompt 42 | 43 | 44 | def post_process_gpt3_response(num_prompt_instructions, response): 45 | if response is None: 46 | return [] 47 | raw_instructions = f"[{num_prompt_instructions+1}]. 指令:" + response.message.content 48 | raw_instructions = re.split("###", raw_instructions) 49 | 50 | instructions = [] 51 | for idx, inst in enumerate(raw_instructions): 52 | # if the decoding stops due to length, the last example is likely truncated so we discard it 53 | if idx == len(raw_instructions) - 1 and response["finish_reason"] == "length": 54 | continue 55 | idx = idx + num_prompt_instructions + 1 56 | splitted_data = re.split(f"\[{idx}\]\.\s+(指令|输入|输出):", inst) 57 | 58 | # print(splitted_data) 59 | if len(splitted_data) != 7: 60 | continue 61 | else: 62 | inst = splitted_data[2].strip() 63 | input = splitted_data[4].strip() 64 | input = "" if input.lower() == "" else input 65 | output = splitted_data[6].strip() 66 | # filter out too short or too long instructions 67 | # if len(inst.split()) <= 3 or len(inst.split()) > 100: 68 | # continue 69 | # # filter based on keywords that are not suitable for language models. 
70 | # blacklist = [ 71 | # "image", 72 | # "images", 73 | # "graph", 74 | # "graphs", 75 | # "picture", 76 | # "pictures", 77 | # "file", 78 | # "files", 79 | # "map", 80 | # "maps", 81 | # "draw", 82 | # "plot", 83 | # "go to", 84 | # "video", 85 | # "audio", 86 | # "music", 87 | # "flowchart", 88 | # "diagram", 89 | # ] 90 | # blacklist += [] 91 | # if any(find_word_in_string(word, inst) for word in blacklist): 92 | # continue 93 | # We found that the model tends to add "write a program" to some existing instructions, which lead to a lot of such instructions. 94 | # And it's a bit comfusing whether the model need to write a program or directly output the result. 95 | # Here we filter them out. 96 | # Note this is not a comprehensive filtering for all programming instructions. 97 | # if inst.startswith("Write a program"): 98 | # continue 99 | # filter those starting with punctuation 100 | # if inst[0] in string.punctuation: 101 | # continue 102 | # filter those starting with non-english character 103 | # if not inst[0].isascii(): 104 | # continue 105 | instructions.append({"instruction": inst, "input": input, "output": output}) 106 | return instructions 107 | 108 | 109 | # def find_word_in_string(w, s): 110 | # return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s) 111 | 112 | 113 | def read_xlsx(path): 114 | book = openpyxl.load_workbook(path) 115 | sheet = book.active 116 | data = [] 117 | for row in range(2, 203): 118 | instruction = sheet.cell(row=row, column=2).value 119 | input = sheet.cell(row=row, column=3).value 120 | output = sheet.cell(row=row, column=4).value 121 | data.append({"instruction": instruction, "input": input, "output": output}) 122 | 123 | return data 124 | 125 | def generate_instruction_following_data( 126 | output_dir="../data/", 127 | seed_tasks_path="../data/seed_task.xlsx", 128 | num_instructions_to_generate=100000, 129 | num_prompt_instructions=3, 130 | request_batch_size=1, 131 | num_cpus=16, 132 | ): 133 | 134 | seed_instruction_data = read_xlsx(seed_tasks_path) 135 | 136 | print(f"Loaded {len(seed_instruction_data)} seed tasks") 137 | 138 | # print(seed_instruction_data[0]) 139 | 140 | # exit(0) 141 | 142 | # seed_instruction_data = [ 143 | # {"instruction": t["instruction"], "input": t["input"], "output": t["output"]} 144 | # for t in seed_tasks 145 | # ] 146 | # print(f"Loaded {len(seed_instruction_data)} human-written seed instructions") 147 | 148 | os.makedirs(output_dir, exist_ok=True) 149 | request_idx = 0 150 | # load the LM-generated instructions 151 | 152 | # if os.path.exists(os.path.join(output_dir, "regen.json")): 153 | # machine_instruction_data = utils.jload(os.path.join(output_dir, "regen.json")) 154 | # print(f"Loaded {len(machine_instruction_data)} machine-generated instructions") 155 | 156 | # similarities = {} 157 | scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False) 158 | 159 | # now let's generate new instructions! 
160 | progress_bar = tqdm.tqdm(total=num_instructions_to_generate) 161 | # if machine_instruction_data: 162 | # progress_bar.update(len(machine_instruction_data)) 163 | 164 | # first we tokenize all the seed instructions and generated machine instructions 165 | all_instructions = [d["instruction"] for d in seed_instruction_data] 166 | all_instruction_tokens = [scorer._tokenizer.tokenize(inst) for inst in all_instructions] 167 | cnt = 0 168 | 169 | while cnt < num_instructions_to_generate: 170 | request_idx += 1 171 | batch_inputs = [] 172 | for _ in range(request_batch_size): 173 | # 随机选取num_prompt_instructions个seed instruction 174 | prompt_instructions = random.sample(seed_instruction_data, num_prompt_instructions) 175 | prompt = encode_prompt(prompt_instructions) 176 | batch_inputs.append(prompt) 177 | decoding_args = utils.OpenAIDecodingArguments() 178 | request_start = time.time() 179 | results = utils.openai_completion( # 一个输入 180 | prompts=batch_inputs, 181 | batch_size=request_batch_size, 182 | decoding_args=decoding_args, 183 | # logit_bias={"50256": -100}, # prevent the <|endoftext|> token from being generated 184 | ) 185 | 186 | request_duration = time.time() - request_start 187 | 188 | process_start = time.time() 189 | instruction_data = [] 190 | for result in results: 191 | new_instructions = post_process_gpt3_response(num_prompt_instructions, result) 192 | cnt += len(new_instructions) 193 | instruction_data += new_instructions 194 | # print(instruction_data) 195 | 196 | # total = len(instruction_data) 197 | # keep = 0 198 | # for instruction_data_entry in instruction_data: 199 | # # computing similarity with the pre-tokenzied instructions 200 | # new_instruction_tokens = scorer._tokenizer.tokenize(instruction_data_entry["instruction"]) 201 | # with Pool(num_cpus) as p: 202 | # rouge_scores = p.map( 203 | # partial(rouge_scorer._score_lcs, new_instruction_tokens), 204 | # all_instruction_tokens, 205 | # ) 206 | # rouge_scores = [score.fmeasure for score in rouge_scores] 207 | # most_similar_instructions = { 208 | # all_instructions[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1] 209 | # } 210 | # if max(rouge_scores) > 0.7: 211 | # continue 212 | # else: 213 | # keep += 1 214 | # instruction_data_entry["most_similar_instructions"] = most_similar_instructions 215 | # instruction_data_entry["avg_similarity_score"] = float(np.mean(rouge_scores)) 216 | # machine_instruction_data.append(instruction_data_entry) 217 | # all_instructions.append(instruction_data_entry["instruction"]) 218 | # all_instruction_tokens.append(new_instruction_tokens) 219 | progress_bar.update(1) 220 | process_duration = time.time() - process_start 221 | print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s") 222 | # print(f"Generated {total} instructions, kept {keep} instructions") 223 | # 写到json文件,一行一个数据 224 | with open(os.path.join(output_dir, "regen.json"), mode='a', encoding='utf-8') as f: 225 | if len(instruction_data) > 0: 226 | f.write('\n'.join(json.dumps(instruction, ensure_ascii=False) for instruction in instruction_data) +'\n') 227 | 228 | # utils.jdump(instruction_data, os.path.join(output_dir, "regen.json")) 229 | 230 | 231 | # def main(task, **kwargs): 232 | # globals()[task](**kwargs) 233 | 234 | 235 | if __name__ == "__main__": 236 | generate_instruction_following_data() -------------------------------------------------------------------------------- /src/data_processing/prompt.txt: 
-------------------------------------------------------------------------------- 1 | 你被要求编写6个不同的中国法律任务指令,前3个已经写好,你只需要写出后面3个。 2 | 以下是要求: 3 | 1.每条指令用不同的动词,以最大限度地提高多样性。 4 | 2.指令使用的语言也应该是多样的。例如,你可以使用疑问句或祈使句。 5 | 3.指令的类型应该多样化。列表应该包括不同类型的任务。 6 | 4.你写的指令是一些法律相关的任务。 7 | 5.指令需要用中文来编写。 8 | 6.不是所有指令都需要有输入。如果指令已经很清晰,那么就不需要输入,输入部分填""即可。 9 | 7.指令的长度应为1至2句话,祈使句或疑问句都可以。 10 | 8.你需要为该指令生成适当的输入。输入字段应包含为该指令提供的特定示例。它应该包含真实的数据,而不应该包含简单的占位符。输入内容应该丰富,使指令具有挑战性,但最好不超过500字。 11 | 9.指令和输入组成一个问题,输出是对这个问题的回答,确保输出不超过500字。 12 | 10.你只需要写出指令、输入和输出三部分。 13 | 14 | 6个任务指令: -------------------------------------------------------------------------------- /src/data_processing/test.py: -------------------------------------------------------------------------------- 1 | 2 | # import re 3 | # from transformers import AutoTokenizer, AutoModelForCausalLM 4 | 5 | # # line = ' 你好 \n 你好 ' 6 | # # print(line) 7 | # # line = re.sub(r'\s+', '', line) 8 | # # print(line) 9 | 10 | 11 | # def text2token(text_list): 12 | # tokenizer = AutoTokenizer.from_pretrained("/mntcephfs/data/med/wjb/LawChatGPT/pretrain_model") 13 | 14 | # tokens = [] 15 | # for text in text_list: 16 | # text = text + tokenizer.eos_token 17 | # token = tokenizer.encode(text) 18 | # tokens.extend(token) 19 | 20 | # print(len(tokens)) 21 | 22 | # # print(tokens) 23 | # # print(tokenizer.decode(tokens)) 24 | 25 | 26 | # text_list = [ 27 | # "审判长伍辉耘\n审判员徐伟\n审判员缪崇民\n二〇一三年三月二十七日\n书记员王佳", 28 | # "衢州市衢江区人口和计划生育局、衢州市衢江区人口和计划生育局依据已经发生法与邵红楼、汤明仙非诉执行审查裁定书", 29 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 30 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 31 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 32 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 33 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 34 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 35 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 36 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。", 37 | # "上诉人深圳市规划和国土资源委员会因与被上诉人深圳市投资控股有限公司房地产权抵押登记纠纷一案,不服深圳市罗湖区人民法院(2011)深罗法行初字第8号行政判决,向本院提起上诉。本院依法组成合议庭进行了审理,本案现已审理终结。" 38 | # ] 39 | # text2token(text_list) 40 | 41 | # content = re.sub(r'<[^>]+>', '', "aaa") 42 | 43 | # # content = re.findall(r'<[^>]+>([^<]+)]+>', ) 44 | 45 | # print(content) 46 | 47 | 48 | # res = [] 49 | 50 | 51 | # import os 52 | # def gci(filepath): 53 | # files = os.listdir(filepath) 54 | # for fi in files: 55 | # fi_d = os.path.join(filepath, fi) 56 | # if os.path.isdir(fi_d): 57 | # gci(fi_d) 58 | # else: 59 | # res.append(os.path.join(filepath, fi_d)) 60 | 61 | # gci('/mntcephfs/data/med/wjb/LawChatGPT/data/qa') 62 | 63 | # print(res) 64 | # print(len(res)) 65 | 66 | 67 | # list_1 = [ 68 | # [1,2,3], 69 | # [5,43,62,4], 70 | # [2,2,3], 71 | # ] 72 | 73 | # import random 74 | # random.shuffle(list_1) 75 | # print(list_1) 76 | 77 | print('\ndfdsafsa \n dfa\n'.strip()) -------------------------------------------------------------------------------- /src/data_processing/utils.py: -------------------------------------------------------------------------------- 1 
| import dataclasses 2 | import logging 3 | import math 4 | import os 5 | import io 6 | import sys 7 | import time 8 | import json 9 | from typing import Optional, Sequence, Union 10 | 11 | import openai 12 | import tqdm 13 | from openai import openai_object 14 | import copy 15 | 16 | 17 | # openai.api_key = "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" 18 | 19 | # openai.api_key = "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" # wo 20 | 21 | openai.api_key = "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" # qin 22 | 23 | 24 | # model_list = openai.Model.list() 25 | 26 | # datas = model_list['data'] 27 | 28 | # models = [data['id'] for data in datas] 29 | 30 | # print(models) 31 | 32 | 33 | StrOrOpenAIObject = Union[str, openai_object.OpenAIObject] 34 | 35 | # openai_org = os.getenv("OPENAI_ORG") 36 | # if openai_org is not None: 37 | # openai.organization = openai_org 38 | # logging.warning(f"Switching to organization: {openai_org} for OAI API key.") 39 | 40 | 41 | @dataclasses.dataclass 42 | class OpenAIDecodingArguments(object): 43 | model : str = "gpt-3.5-turbo" 44 | # max_tokens: int = 2048 45 | temperature: float = 0.8 46 | top_p: float = 1.0 47 | n: int = 1 48 | stream: bool = False 49 | # stop: Optional[Sequence[str]] = None 50 | presence_penalty: float = 0.0 51 | frequency_penalty: float = 0.0 52 | # suffix: Optional[str] = None 53 | # logprobs: Optional[int] = None 54 | # echo: bool = False 55 | logit_bias = {"50256": -100} 56 | stop=["\n[6].", "[6]."] 57 | 58 | def openai_completion( 59 | prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]], 60 | decoding_args: OpenAIDecodingArguments, 61 | sleep_time=2, 62 | batch_size=1, 63 | max_instances=sys.maxsize, 64 | max_batches=sys.maxsize, 65 | return_text=False, 66 | ) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]],]: 67 | 68 | is_single_prompt = isinstance(prompts, (str, dict)) 69 | if is_single_prompt: 70 | prompts = [prompts] 71 | 72 | if max_batches < sys.maxsize: 73 | logging.warning( 74 | "`max_batches` will be deprecated in the future, please use `max_instances` instead." 75 | "Setting `max_instances` to `max_batches * batch_size` for now." 
76 | ) 77 | max_instances = max_batches * batch_size 78 | 79 | prompts = prompts[:max_instances] 80 | num_prompts = len(prompts) 81 | prompt_batches = [ 82 | prompts[batch_id * batch_size : (batch_id + 1) * batch_size] 83 | for batch_id in range(int(math.ceil(num_prompts / batch_size))) 84 | ] 85 | 86 | completions = [] 87 | for prompt_batch in prompt_batches: 88 | batch_decoding_args = copy.deepcopy(decoding_args) # cloning the decoding_args 89 | while True: 90 | try: 91 | shared_kwargs = dict( 92 | **batch_decoding_args.__dict__, 93 | ) 94 | # print(shared_kwargs) 95 | messages = [ 96 | {"role":"system","content": "你是一个法律助手。"}, 97 | {"role":"user","content": prompt_batch[0]} 98 | ] 99 | completion_batch = openai.ChatCompletion.create(messages=messages, **shared_kwargs) 100 | 101 | # print(prompt_batch) 102 | # print(len(prompt_batch[0])) 103 | # completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs) 104 | choices = completion_batch.choices 105 | # text = choices[0].text 106 | # print(text) 107 | # print(len(text)) 108 | # completions.append(text) 109 | for choice in choices: 110 | # print(choice.message.content) 111 | choice["total_tokens"] = completion_batch.usage.total_tokens 112 | completions.extend(choices) 113 | break 114 | except openai.error.OpenAIError as e: 115 | print(e) 116 | # print(len(prompt_batch[0])) 117 | # max_len = int(max_len * 0.9) 118 | # logging.warning(f"OpenAIError: {e}.") 119 | # if "Please reduce your prompt" in str(e): 120 | # batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.9) 121 | # logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...") 122 | # else: 123 | # logging.warning("Hit request rate limit; retrying...") 124 | # time.sleep(sleep_time) # Annoying rate limit on requests. 125 | 126 | if return_text: 127 | completions = [completion.text for completion in completions] 128 | if decoding_args.n > 1: 129 | # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries. 130 | completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)] 131 | if is_single_prompt: 132 | # Return non-tuple if only 1 input and 1 generation. 133 | (completions,) = completions 134 | return completions 135 | 136 | 137 | def _make_w_io_base(f, mode: str): 138 | if not isinstance(f, io.IOBase): 139 | f_dirname = os.path.dirname(f) 140 | if f_dirname != "": 141 | os.makedirs(f_dirname, exist_ok=True) 142 | f = open(f, mode=mode, encoding="utf-8") 143 | return f 144 | 145 | 146 | def _make_r_io_base(f, mode: str): 147 | if not isinstance(f, io.IOBase): 148 | f = open(f, mode=mode) 149 | return f 150 | 151 | 152 | def jdump(obj, f, mode="a", indent=4, default=str): 153 | """Dump a str or dictionary to a file in json format. 154 | 155 | Args: 156 | obj: An object to be written. 157 | f: A string path to the location on disk. 158 | mode: Mode for opening the file. 159 | indent: Indent for storing json dictionaries. 160 | default: A function to handle non-serializable entries; defaults to `str`. 
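    Example (illustrative; the path is a placeholder):
        jdump([{"instruction": "...", "input": "", "output": "..."}], "../data/regen.json")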
161 | """ 162 | f = _make_w_io_base(f, mode) 163 | if isinstance(obj, (dict, list)): 164 | json.dump(obj, f, indent=indent, default=default) 165 | elif isinstance(obj, str): 166 | f.write(obj) 167 | else: 168 | raise ValueError(f"Unexpected type: {type(obj)}") 169 | f.close() 170 | 171 | 172 | def jload(f, mode="r"): 173 | """Load a .json file into a dictionary.""" 174 | f = _make_r_io_base(f, mode) 175 | jdict = json.load(f) 176 | f.close() 177 | return jdict 178 | -------------------------------------------------------------------------------- /src/flash_bloom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/flash_bloom/__init__.py -------------------------------------------------------------------------------- /src/flash_bloom/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from flash_attn.flash_attn_interface import flash_attn_unpadded_func 4 | 5 | 6 | def _flash_attn(q, k, v, mask=None, bias=None): 7 | batch_dims = q.shape[:-3] 8 | no_heads, n, c = q.shape[-3:] 9 | dtype = q.dtype 10 | 11 | k_no_heads, k_n, k_c = k.shape[-3:] 12 | 13 | # [*, B, N, H, C] 14 | q = q.transpose(-2, -3) 15 | k = k.transpose(-2, -3) 16 | v = v.transpose(-2, -3) 17 | 18 | # [B_flat, N, H, C] 19 | q = q.reshape(-1, *q.shape[-3:]) 20 | k = k.reshape(-1, *k.shape[-3:]) 21 | v = v.reshape(-1, *v.shape[-3:]) 22 | 23 | # Flattened batch size 24 | batch_size = q.shape[0] 25 | k_batch_size = k.shape[0] 26 | 27 | # [B_flat * N, H, C] 28 | q = q.reshape(-1, *q.shape[-2:]) 29 | k = k.reshape(-1, *k.shape[-2:]) 30 | v = v.reshape(-1, *v.shape[-2:]) 31 | 32 | q_max_s = n 33 | q_cu_seqlens = torch.arange( 34 | 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device 35 | ) 36 | 37 | k_max_s = k_n 38 | k_cu_seqlens = torch.arange( 39 | 0, (k_batch_size + 1) * k_n, step=k_n, dtype=torch.int32, device=k.device 40 | ) 41 | 42 | if mask is not None: 43 | mask_heads, tgt_len, src_len = mask.shape[-3:] 44 | mask = mask.reshape(-1, mask_heads, tgt_len, src_len).contiguous() 45 | 46 | if bias is not None: 47 | bias_heads, tgt_len, src_len = bias.shape[-3:] 48 | bias = bias.reshape(-1, bias_heads, tgt_len, src_len).contiguous() 49 | 50 | out = flash_attn_unpadded_func( 51 | q, 52 | k, 53 | v, 54 | q_cu_seqlens, 55 | k_cu_seqlens, 56 | q_max_s, 57 | k_max_s, 58 | attn_mask=mask, 59 | attn_bias=bias, 60 | dropout_p=0., 61 | softmax_scale=1., # q has been scaled already 62 | ) 63 | 64 | # [*, B, N, H, C] 65 | out = out.reshape(*batch_dims, n, no_heads, c) 66 | return out 67 | -------------------------------------------------------------------------------- /src/flash_bloom/bert_padding.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange, repeat 7 | 8 | 9 | class IndexFirstAxis(torch.autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, input, indices): 13 | ctx.save_for_backward(indices) 14 | assert input.ndim >= 2 15 | ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] 16 | second_dim = other_shape.numel() 17 | # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 
18 | # return input[indices] 19 | return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, 20 | repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | indices, = ctx.saved_tensors 25 | assert grad_output.ndim >= 2 26 | other_shape = grad_output.shape[1:] 27 | grad_output = rearrange(grad_output, 'b ... -> b (...)') 28 | grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]], 29 | device=grad_output.device, dtype=grad_output.dtype) 30 | # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 31 | # grad_input[indices] = grad_output 32 | grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output) 33 | return grad_input.reshape(ctx.first_axis_dim, *other_shape), None 34 | 35 | 36 | index_first_axis = IndexFirstAxis.apply 37 | 38 | 39 | class IndexPutFirstAxis(torch.autograd.Function): 40 | 41 | @staticmethod 42 | def forward(ctx, values, indices, first_axis_dim): 43 | ctx.save_for_backward(indices) 44 | assert indices.ndim == 1 45 | assert values.ndim >= 2 46 | output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, 47 | dtype=values.dtype) 48 | # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 49 | output[indices] = values 50 | # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | indices, = ctx.saved_tensors 56 | # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 57 | grad_values = grad_output[indices] 58 | # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) 59 | return grad_values, None, None 60 | 61 | 62 | index_put_first_axis = IndexPutFirstAxis.apply 63 | 64 | 65 | class IndexFirstAxisResidual(torch.autograd.Function): 66 | 67 | @staticmethod 68 | def forward(ctx, input, indices): 69 | ctx.save_for_backward(indices) 70 | assert input.ndim >= 2 71 | ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] 72 | second_dim = other_shape.numel() 73 | # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 74 | output = input[indices] 75 | # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last 76 | # memory format to channel_first. In other words, input might not be contiguous. 77 | # If we don't detach, Pytorch complains about output being a view and is being modified inplace 78 | return output, input.detach() 79 | 80 | @staticmethod 81 | def backward(ctx, grad_output, grad_residual): 82 | indices, = ctx.saved_tensors 83 | assert grad_output.ndim >= 2 84 | other_shape = grad_output.shape[1:] 85 | assert grad_residual.shape[1:] == other_shape 86 | grad_input = grad_residual 87 | # grad_input[indices] += grad_output 88 | indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) 89 | indices = indices.expand_as(grad_output) 90 | grad_input.scatter_add_(0, indices, grad_output) 91 | return grad_input.reshape(ctx.first_axis_dim, *other_shape), None 92 | 93 | 94 | index_first_axis_residual = IndexFirstAxisResidual.apply 95 | 96 | 97 | def unpad_input(hidden_states, attention_mask): 98 | """ 99 | Arguments: 100 | hidden_states: (batch, seqlen, ...) 101 | attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. 
102 | Return: 103 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 104 | cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. 105 | max_seqlen_in_batch: int 106 | """ 107 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 108 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 109 | max_seqlen_in_batch = seqlens_in_batch.max().item() 110 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 111 | # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the 112 | # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim 113 | # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to 114 | # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, 115 | # so we write custom forward and backward to make it a bit faster. 116 | return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, 117 | cu_seqlens, max_seqlen_in_batch) 118 | 119 | 120 | def pad_input(hidden_states, indices, batch, seqlen): 121 | """ 122 | Arguments: 123 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 124 | indices: (total_nnz) 125 | Return: 126 | hidden_states: (batch, seqlen, ...) 127 | """ 128 | dim = hidden_states.shape[-1] 129 | # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) 130 | # output[indices] = hidden_states 131 | output = index_put_first_axis(hidden_states, indices, batch * seqlen) 132 | return rearrange(output, '(b s) ... -> b s ...', b=batch) 133 | -------------------------------------------------------------------------------- /src/flash_bloom/bloom_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.bloom.modeling_bloom import dropout_add 8 | from flash_attn.modules.mha import FlashSelfAttention 9 | 10 | from einops import rearrange 11 | 12 | 13 | def forward( 14 | self, 15 | hidden_states: torch.Tensor, 16 | residual: torch.Tensor, 17 | alibi: torch.Tensor, 18 | attention_mask: torch.Tensor, 19 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 20 | head_mask: Optional[torch.Tensor] = None, 21 | use_cache: bool = False, 22 | output_attentions: bool = False, 23 | ): 24 | fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] 25 | 26 | # 3 x [batch_size, seq_length, num_heads, head_dim] 27 | (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) 28 | 29 | batch_size, q_length, _, _ = query_layer.shape 30 | 31 | query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) 32 | key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) 33 | value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) 34 | if layer_past is not None: 35 | past_key, past_value = layer_past 36 | # concatenate along seq_length dimension: 37 | # - key: [batch_size * self.num_heads, head_dim, kv_length] 38 | # - value: [batch_size * self.num_heads, kv_length, head_dim] 39 | key_layer = 
torch.cat((past_key, key_layer), dim=2) 40 | value_layer = torch.cat((past_value, value_layer), dim=1) 41 | 42 | _, _, kv_length = key_layer.shape 43 | 44 | assert not output_attentions, "output_attentions is not supported" 45 | assert not use_cache, "use_cache is not supported" 46 | 47 | if use_cache is True: 48 | present = (key_layer, value_layer) 49 | else: 50 | present = None 51 | 52 | # [batch_size * num_heads, q_length, kv_length] 53 | # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 54 | # matmul_result = alibi.baddbmm( 55 | # batch1=query_layer, 56 | # batch2=key_layer, 57 | # beta=self.beta, 58 | # alpha=self.inv_norm_factor, 59 | # ) 60 | # 61 | # # change view to [batch_size, num_heads, q_length, kv_length] 62 | # attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) 63 | # 64 | # # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] 65 | # input_dtype = attention_scores.dtype 66 | # # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` 67 | # if input_dtype == torch.float16: 68 | # attention_scores = attention_scores.to(torch.float) 69 | # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) 70 | # attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) 71 | # 72 | # # [batch_size, num_heads, q_length, kv_length] 73 | # attention_probs = self.attention_dropout(attention_probs) 74 | # 75 | # if head_mask is not None: 76 | # attention_probs = attention_probs * head_mask 77 | # 78 | # # change view [batch_size x num_heads, q_length, kv_length] 79 | # attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) 80 | # 81 | # # matmul: [batch_size * num_heads, q_length, head_dim] 82 | # context_layer = torch.bmm(attention_probs_reshaped, value_layer) 83 | # 84 | # # change view [batch_size, num_heads, q_length, head_dim] 85 | # context_layer = self._merge_heads(context_layer) 86 | # 87 | # # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 88 | # if self.pretraining_tp > 1 and self.slow_but_exact: 89 | # slices = self.hidden_size / self.pretraining_tp 90 | # output_tensor = torch.zeros_like(context_layer) 91 | # for i in range(self.pretraining_tp): 92 | # output_tensor = output_tensor + F.linear( 93 | # context_layer[:, :, int(i * slices): int((i + 1) * slices)], 94 | # self.dense.weight[:, int(i * slices): int((i + 1) * slices)], 95 | # ) 96 | # else: 97 | # output_tensor = self.dense(context_layer) 98 | # 99 | # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) 100 | 101 | # ===================process to flash accept form ================# 102 | reshaped_query_layer = query_layer.reshape(batch_size, self.num_heads, query_layer.shape[1], 103 | query_layer.shape[2]).permute(0, 2, 1, 3) 104 | reshaped_key_layer = key_layer.reshape(batch_size, self.num_heads, key_layer.shape[1], 105 | key_layer.shape[2]).permute(0, 3, 1, 2) 106 | reshaped_value_layer = value_layer.reshape(batch_size, self.num_heads, value_layer.shape[1], 107 | value_layer.shape[2]).permute(0, 2, 1, 3) 108 | offset_key_layer = self.inv_norm_factor * reshaped_key_layer + self.beta * ( 109 | torch.linalg.pinv(reshaped_query_layer.permute(0, 2, 1, 3).float()) * alibi.view(batch_size, 110 | alibi.shape[ 111 | 0] // batch_size, 112 | alibi.shape[1], 113 | alibi.shape[ 114 | 2])).permute(0, 3, 115 | 1, 116 | 2).half() 117 | qkv = torch.concat( 118 | [reshaped_query_layer.unsqueeze(2), offset_key_layer.unsqueeze(2), reshaped_value_layer.unsqueeze(2)], 119 | dim=2).half() 120 | if not hasattr(self, 'flash_self_attention'): 121 | self.flash_self_attention = FlashSelfAttention(causal=True, softmax_scale=1) 122 | context_layer = self.flash_self_attention(qkv) 123 | context_layer = torch.flatten(context_layer, start_dim=2) 124 | 125 | # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 126 | if self.pretraining_tp > 1 and self.slow_but_exact: 127 | slices = self.hidden_size / self.pretraining_tp 128 | output_tensor = torch.zeros_like(context_layer) 129 | for i in range(self.pretraining_tp): 130 | output_tensor = output_tensor + torch.nn.functional.linear( 131 | context_layer[:, :, int(i * slices): int((i + 1) * slices)], 132 | self.dense.weight[:, int(i * slices): int((i + 1) * slices)], 133 | ) 134 | else: 135 | output_tensor = self.dense(context_layer) 136 | 137 | output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) 138 | 139 | outputs = (output_tensor, present) 140 | 141 | return outputs 142 | 143 | 144 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 145 | # requires the attention mask to be the same as the key_padding_mask 146 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 147 | inputs_embeds, past_key_values_length): 148 | # [bsz, seq_len] 149 | return attention_mask 150 | 151 | 152 | def replace_bloom_attn_with_flash_attn(): 153 | # transformers.models.bloom.modeling_bloom.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 154 | transformers.models.bloom.modeling_bloom.BloomAttention.forward = forward 155 | -------------------------------------------------------------------------------- /src/flash_bloom/flash_attn_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb 3 | from typing import Optional 4 | 5 | from flash_attn.modules.mha import FlashSelfAttention 6 | from transformers.models.bloom.modeling_bloom import dropout_add 7 | from typing import Optional, Tuple 8 | 9 | 10 | class FlashAttentionWrapper(torch.nn.Module): 11 | def __init__(self, attention, max_seqlen=8190): 12 | super().__init__() 13 | self.attention = attention 14 | self.max_seqlen = max_seqlen 15 | self.flash_self_attention = FlashSelfAttention(causal=True, softmax_scale=1) 16 | self.dropout_p = 0.0 17 | 18 | def forward( 19 | self, 20 | hidden_states: torch.Tensor, 21 | key_value_states: Optional[torch.Tensor] = None, 22 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | layer_head_mask: Optional[torch.Tensor] = None, 25 | output_attentions: bool = False, 26 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 27 | """Input shape: Batch x Time x Channel""" 28 | 29 | # if key_value_states are provided this layer is used as a cross-attention layer 30 | # for the decoder 31 | is_cross_attention = key_value_states is not None 32 | 33 | bsz, tgt_len, _ = hidden_states.size() 34 | 35 | # get query proj 36 | query_states = self.attention.q_proj(hidden_states) * self.attention.scaling 37 | # get key, value proj 38 | if is_cross_attention and past_key_value is not None: 39 | # reuse k,v, cross_attentions 40 | key_states = past_key_value[0] 41 | value_states = past_key_value[1] 42 | elif is_cross_attention: 43 | # cross_attentions 44 | key_states = self.attention._shape(self.attention.k_proj(key_value_states), -1, bsz) 45 | value_states = self.attention._shape(self.attention.v_proj(key_value_states), -1, bsz) 46 | elif past_key_value is not None: 47 | # reuse k, v, self_attention 48 | key_states = self.attention._shape(self.attention.k_proj(hidden_states), -1, bsz) 49 | value_states = 
self.attention._shape(self.attention.v_proj(hidden_states), -1, bsz) 50 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 51 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 52 | else: 53 | # self_attention 54 | key_states = self.attention._shape(self.attention.k_proj(hidden_states), -1, bsz) 55 | value_states = self.attention._shape(self.attention.v_proj(hidden_states), -1, bsz) 56 | 57 | if self.attention.is_decoder: 58 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 59 | # Further calls to cross_attention layer can then reuse all cross-attention 60 | # key/value_states (first "if" case) 61 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 62 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 63 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 64 | # if encoder bi-directional self-attention `past_key_value` is always `None` 65 | past_key_value = (key_states, value_states) 66 | 67 | proj_shape = (bsz, self.attention.num_heads, -1, self.attention.head_dim) 68 | key_states = key_states.view(*proj_shape).permute(0, 2, 1, 3) 69 | query_states = self.attention._shape(query_states, tgt_len, bsz).permute(0, 2, 1, 3) 70 | value_states = value_states.view(*proj_shape).permute(0, 2, 1, 3) 71 | qkv = torch.concat([query_states.unsqueeze(2), key_states.unsqueeze(2), value_states.unsqueeze(2)], 72 | dim=2).half() 73 | attn_output = self.flash_self_attention(qkv) 74 | attn_weights_reshaped = None 75 | 76 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 77 | # partitioned aross GPUs when using tensor-parallelism. 
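# FlashSelfAttention should return [bsz, tgt_len, num_heads, head_dim] for the packed
# qkv above; the reshape below merges the heads back into embed_dim before out_proj.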
78 | attn_output = attn_output.reshape(bsz, tgt_len, self.attention.embed_dim) 79 | 80 | attn_output = self.attention.out_proj(attn_output) 81 | 82 | return attn_output, attn_weights_reshaped, past_key_value 83 | 84 | 85 | class FlashAttentionWrapperWithRotary(torch.nn.Module): 86 | def __init__(self, attention, max_seqlen=8192): 87 | super().__init__() 88 | self.attention = attention 89 | self.max_seqlen = max_seqlen 90 | self.flash_self_attention = FlashSelfAttention(causal=True, softmax_scale=1 / self.attention.norm_factor) 91 | self.dropout_p = 0.0 92 | 93 | def forward(self, 94 | hidden_states, 95 | attention_mask, 96 | head_mask=None, 97 | layer_past=None, 98 | use_cache=False, 99 | output_attentions=False): 100 | has_layer_past = layer_past is not None 101 | 102 | # Compute QKV 103 | # Attention heads [batch, seq_len, hidden_size] 104 | # --> [batch, seq_len, (np * 3 * head_size)] 105 | qkv = self.attention.query_key_value(hidden_states) 106 | 107 | # [batch, seq_len, (num_heads * 3 * head_size)] 108 | # --> [batch, seq_len, num_heads, 3 * head_size] 109 | new_qkv_shape = qkv.size()[:-1] + (self.attention.num_attention_heads, 3 * self.attention.head_size) 110 | qkv = qkv.view(*new_qkv_shape) 111 | 112 | # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] 113 | query = qkv[..., : self.attention.head_size].permute(0, 2, 1, 3) 114 | key = qkv[..., self.attention.head_size: 2 * self.attention.head_size].permute(0, 2, 1, 3) 115 | value = qkv[..., 2 * self.attention.head_size:].permute(0, 2, 1, 3) 116 | 117 | # Compute rotary embeddings on rotary_ndims 118 | query_rot = query[..., : self.attention.rotary_ndims] 119 | query_pass = query[..., self.attention.rotary_ndims:] 120 | key_rot = key[..., : self.attention.rotary_ndims] 121 | key_pass = key[..., self.attention.rotary_ndims:] 122 | 123 | # Compute token offset for rotary embeddings (when decoding) 124 | seq_len = key.shape[-2] 125 | offset = 0 126 | if has_layer_past: 127 | offset = layer_past[0].shape[-2] 128 | seq_len += offset 129 | cos, sin = self.attention.rotary_emb(value, seq_len=seq_len) 130 | query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset) 131 | query = torch.cat((query, query_pass), dim=-1) 132 | key = torch.cat((key, key_pass), dim=-1) 133 | 134 | # Cache QKV values 135 | if has_layer_past: 136 | past_key = layer_past[0] 137 | past_value = layer_past[1] 138 | key = torch.cat((past_key, key), dim=-2) 139 | value = torch.cat((past_value, value), dim=-2) 140 | present = (key, value) if use_cache else None 141 | 142 | # Compute attention 143 | # attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 144 | 145 | qkv = torch.concat([query.unsqueeze(2), key.unsqueeze(2), value.unsqueeze(2)], dim=2).permute(0, 3, 2, 1, 146 | 4).half() 147 | attn_output = self.flash_self_attention(qkv) 148 | attn_weights = None 149 | 150 | # Reshape outputs 151 | attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), 152 | self.attention.num_attention_heads * self.attention.head_size) 153 | attn_output = self.attention.dense(attn_output) 154 | 155 | outputs = (attn_output, present) 156 | if output_attentions: 157 | outputs += (attn_weights,) 158 | 159 | return outputs 160 | 161 | 162 | class FlashAttentionWrapperWithAlibi(torch.nn.Module): 163 | def __init__(self, attention, max_seqlen=8192): 164 | super().__init__() 165 | self.attention = attention 166 | self.max_seqlen = max_seqlen 167 | self.flash_self_attention 
= FlashSelfAttention(causal=True, softmax_scale=None) 168 | self.dropout_p = 0.0 169 | 170 | def forward(self, 171 | hidden_states: torch.Tensor, 172 | residual: torch.Tensor, 173 | alibi: torch.Tensor, 174 | attention_mask: torch.Tensor, 175 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 176 | head_mask: Optional[torch.Tensor] = None, 177 | use_cache: bool = False, 178 | output_attentions: bool = False, 179 | ): 180 | fused_qkv = self.attention.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] 181 | 182 | # 3 x [batch_size, seq_length, num_heads, head_dim] 183 | (query_layer, key_layer, value_layer) = self.attention._split_heads(fused_qkv) 184 | batch_size, q_length, _, _ = query_layer.shape 185 | 186 | query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.attention.num_heads, q_length, 187 | self.attention.head_dim) 188 | key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.attention.num_heads, 189 | self.attention.head_dim, q_length) 190 | value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.attention.num_heads, q_length, 191 | self.attention.head_dim) 192 | 193 | if layer_past is not None: 194 | past_key, past_value = layer_past 195 | # concatenate along seq_length dimension: 196 | # - key: [batch_size * self.num_heads, head_dim, kv_length] 197 | # - value: [batch_size * self.num_heads, kv_length, head_dim] 198 | key_layer = torch.cat((past_key, key_layer), dim=2) 199 | value_layer = torch.cat((past_value, value_layer), dim=1) 200 | 201 | if use_cache is True: 202 | present = (key_layer, value_layer) 203 | else: 204 | present = None 205 | 206 | reshaped_query_layer = query_layer.reshape(batch_size, self.attention.num_heads, query_layer.shape[1], 207 | query_layer.shape[2]).permute(0, 2, 1, 3) 208 | reshaped_key_layer = key_layer.reshape(batch_size, self.attention.num_heads, key_layer.shape[1], 209 | key_layer.shape[2]).permute(0, 3, 1, 2) 210 | reshaped_value_layer = value_layer.reshape(batch_size, self.attention.num_heads, value_layer.shape[1], 211 | value_layer.shape[2]).permute(0, 2, 1, 3) 212 | offset_key_layer = self.attention.inv_norm_factor * reshaped_key_layer + self.attention.beta * ( 213 | torch.linalg.pinv(reshaped_query_layer.permute(0, 2, 1, 3).float()) * alibi.view(batch_size, 214 | alibi.shape[ 215 | 0] // batch_size, 216 | alibi.shape[1], 217 | alibi.shape[ 218 | 2])).permute(0, 219 | 3, 220 | 1, 221 | 2).half() 222 | qkv = torch.concat( 223 | [reshaped_query_layer.unsqueeze(2), offset_key_layer.unsqueeze(2), reshaped_value_layer.unsqueeze(2)], 224 | dim=2).half() 225 | context_layer = self.flash_self_attention(qkv) 226 | context_layer = torch.flatten(context_layer, start_dim=2) 227 | 228 | # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 229 | if self.attention.pretraining_tp > 1 and self.attention.slow_but_exact: 230 | slices = self.attention.hidden_size / self.attention.pretraining_tp 231 | output_tensor = torch.zeros_like(context_layer) 232 | for i in range(self.attention.pretraining_tp): 233 | output_tensor = output_tensor + F.linear( 234 | context_layer[:, :, int(i * slices): int((i + 1) * slices)], 235 | self.attention.dense.weight[:, int(i * slices): int((i + 1) * slices)], 236 | ) 237 | else: 238 | output_tensor = self.attention.dense(context_layer) 239 | 240 | output_tensor = dropout_add(output_tensor, residual, self.attention.hidden_dropout, self.attention.training) 241 | 242 | outputs = (output_tensor, present) 243 | return outputs 244 | -------------------------------------------------------------------------------- /src/flash_bloom/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from einops import rearrange 10 | 11 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 12 | from flash_attn.bert_padding import unpad_input, pad_input 13 | 14 | 15 | def forward( 16 | self, 17 | hidden_states: torch.Tensor, 18 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.LongTensor] = None, 21 | output_attentions: bool = False, 22 | use_cache: bool = False, 23 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 24 | Optional[Tuple[torch.Tensor]]]: 25 | """Input shape: Batch x Time x Channel 26 | 27 | attention_mask: [bsz, q_len] 28 | """ 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view( 32 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 33 | key_states = self.k_proj(hidden_states).view( 34 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 35 | value_states = self.v_proj(hidden_states).view( 36 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 37 | # [bsz, q_len, nh, hd] 38 | # [bsz, nh, q_len, hd] 39 | 40 | kv_seq_len = key_states.shape[-2] 41 | if past_key_value is not None: 42 | kv_seq_len += past_key_value[0].shape[-2] 43 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 44 | query_states, key_states = apply_rotary_pos_emb(query_states, 45 | key_states, 46 | cos, 47 | sin, 48 | position_ids=position_ids) 49 | # [bsz, nh, t, hd] 50 | # use_cache=False 51 | assert not output_attentions, "output_attentions is not supported" 52 | assert not use_cache, "use_cache is not supported" 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | # Flash attention codes from 56 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 57 | 58 | # transform the data into the format required by flash attention 59 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 60 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 61 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 62 | # the attention_mask should be the same as the key_padding_mask 63 | key_padding_mask = attention_mask 64 | 65 | if key_padding_mask is None: 66 | qkv = rearrange(qkv, 'b s ... 
-> (b s) ...') 67 | max_s = q_len 68 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 69 | device=qkv.device) 70 | output = flash_attn_unpadded_qkvpacked_func( 71 | qkv, cu_q_lens, max_s, 0.0, 72 | softmax_scale=None, causal=True 73 | ) 74 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz) 75 | else: 76 | nheads = qkv.shape[-2] 77 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 78 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 79 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 80 | output_unpad = flash_attn_unpadded_qkvpacked_func( 81 | x_unpad, cu_q_lens, max_s, 0.0, 82 | softmax_scale=None, causal=True 83 | ) 84 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 85 | indices, bsz, q_len), 86 | 'b s (h d) -> b s h d', h=nheads) 87 | return self.o_proj(rearrange(output, 88 | 'b s h d -> b s (h d)')), None, None 89 | 90 | 91 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 92 | # requires the attention mask to be the same as the key_padding_mask 93 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 94 | inputs_embeds, past_key_values_length): 95 | # [bsz, seq_len] 96 | return attention_mask 97 | 98 | 99 | def replace_llama_attn_with_flash_attn(): 100 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 101 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 102 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/basic_pl_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | """ 6 | 7 | import sys 8 | import glob 9 | import os 10 | from pathlib import Path 11 | 12 | import pytorch_lightning as pl 13 | 14 | from src.configuration.task.config_args import parse_args_for_config 15 | from src.utils.wrapper import print_done 16 | from src.utils.string_utils import are_same_strings 17 | from src.models.basic_trainer import BasicTrainer 18 | from src.modules.pl_callbacks import Seq2SeqLoggingCallback, Seq2SeqCheckpointCallback 19 | 20 | 21 | class BasicPLTrainer(BasicTrainer): 22 | def __init__(self, args, trainer_name="basic-pl-trainer"): 23 | # parameters 24 | super().__init__(args, trainer_name=trainer_name) 25 | 26 | # customized variables 27 | self.pl_trainer = None 28 | self.device = self.model.device if hasattr(self.model, "device") else self.device 29 | 30 | @print_done(desc="Creating directories and fix random seeds") 31 | def _init_args(self, args): 32 | self.output_dir.mkdir(parents=True, exist_ok=True) 33 | self.experiment_output_dir.mkdir(parents=True, exist_ok=True) 34 | self.log_dir.mkdir(parents=True, exist_ok=True) 35 | self.save_dir.mkdir(parents=True, exist_ok=True) 36 | self.cache_dir.mkdir(parents=True, exist_ok=True) 37 | 38 | pl.seed_everything(args.seed, workers=True) # reproducibility 39 | 40 | @print_done(desc="initialize model") 41 | def _init_model(self, args): 42 | # automatically download from huggingface 43 | print(f"model_path: 
{args.model_name_or_path}") 44 | raise NotImplementedError(f"args.model_name: {args.model_name}") 45 | 46 | @print_done("set up pytorch lightning trainer") 47 | def _init_pl_trainer(self, args, model, logger): 48 | extra_callbacks = [] 49 | self.checkpoint_callback = Seq2SeqCheckpointCallback( 50 | output_dir=self.save_dir, 51 | experiment_name=self.experiment_name, 52 | monitor="val_loss", 53 | save_top_k=args.save_top_k, 54 | every_n_train_steps=args.every_n_train_steps, 55 | verbose=args.ckpt_verbose, 56 | ) 57 | 58 | # initialize pl_trainer 59 | if args.gpus is not None and args.gpus > 1: 60 | self.train_params["distributed_backend"] = "ddp" 61 | 62 | self.pl_trainer: pl.Trainer = pl.Trainer.from_argparse_args( 63 | args, 64 | enable_model_summary=False, 65 | callbacks=[self.checkpoint_callback, Seq2SeqLoggingCallback(), pl.callbacks.ModelSummary(max_depth=1)] 66 | + extra_callbacks, 67 | logger=logger, 68 | **self.train_params, 69 | ) 70 | 71 | @property 72 | def checkpoints(self): 73 | return list(sorted(glob.glob(os.path.join(self.save_dir, "*.ckpt"), recursive=True))) 74 | 75 | def auto_find_lr_rate(self): 76 | if self.pl_trainer.auto_lr_find: 77 | self.pl_trainer.tune(self.model) 78 | print(f"after tuning: {self.model.learning_rate}") 79 | 80 | # 开始寻找合适的学习率 81 | lr_finder = self.model.tuner.lr_find(self.model) 82 | # 展示loss和学习率的曲线 83 | print(f"auto find the best learning rate of model {self.model.model_name}:\n{lr_finder.results}") 84 | # 设置为推荐的学习率 85 | suggested_lr = lr_finder.suggestion() 86 | print(f"the suggested lr: {suggested_lr}") 87 | self.model.hyparams.learning_rate = suggested_lr 88 | 89 | def auto_find_batch_size(self): 90 | if self.pl_trainer.auto_scale_batch_size == "binsearch": 91 | self.pl_trainer.tune(self.model) 92 | print(f"auto find the best of batch size of {self.model.model_name}:\n{self.pl_trainer.batch_size}") 93 | self.model.hyparams.train_batch_size = self.model.batch_size 94 | 95 | def train(self): 96 | self.auto_find_lr_rate() 97 | self.auto_find_batch_size() 98 | 99 | self.pl_trainer.logger.log_hyperparams(self.args) 100 | 101 | if self.checkpoints: 102 | # training 103 | best_ckpt = self.checkpoints[-1] 104 | self.pl_trainer.fit(self.model, ckpt_path=best_ckpt) 105 | else: 106 | # training 107 | if hasattr(self.model, "init_for_vanilla_weights"): 108 | self.model.init_for_vanilla_weights() 109 | self.pl_trainer.fit(self.model) 110 | -------------------------------------------------------------------------------- /src/models/basic_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | """ 6 | 7 | import os 8 | import sys 9 | import torch 10 | import random 11 | import numpy as np 12 | from pathlib import Path 13 | 14 | FILE_PATH = Path(__file__).absolute() 15 | BASE_DIR = FILE_PATH.parent.parent.parent 16 | sys.path.insert(0, str(BASE_DIR)) # 在tasks文件夹中可以直接运行程序 17 | 18 | from pytorch_lightning.loggers import CSVLogger 19 | from src.utils.wrapper import print_done 20 | 21 | 22 | class BasicTrainer(object): 23 | def __init__(self, args, trainer_name: str = "basic_trainer"): 24 | # parameters 25 | self.trainer_name = trainer_name 26 | self.args = args 27 | self.output_dir = Path(args.output_dir) 28 | self.experiment_name = args.experiment_name 29 | self.experiment_output_dir = self.output_dir.joinpath(self.experiment_name) 30 | self.save_dir = self.experiment_output_dir.joinpath("checkpoints") 31 | self.data_dir = args.data_dir 32 | self.log_dir = 
self.experiment_output_dir.joinpath("logs") 33 | self.cache_dir = self.experiment_output_dir.joinpath("cache") 34 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | self.n_gpu = torch.cuda.device_count() 36 | self.logger = None 37 | self.tokenizer = None 38 | self.vocab = None 39 | self.train_loader = None 40 | self.dev_loader = None 41 | self.test_loader = None 42 | self.criterion = None 43 | self.model = None 44 | self.train_params = {} 45 | 46 | self._init_args(self.args) 47 | 48 | def _init_logger(self, args, model): 49 | """ 50 | - WandbLogger 51 | name: Display name for the run. 52 | project: The name of the project to which this run will belong. 53 | """ 54 | if args.logger_name == "CSVLogger": 55 | self.logger = CSVLogger(save_dir=self.log_dir, name=f'{self.experiment_name}_CSVLogger') 56 | elif args.logger_name == "WandbLogger": 57 | from pytorch_lightning.loggers import WandbLogger 58 | os.environ["WANDB_API_KEY"] = 'xxxxxxxxxxxxxx' # TODO your api key 59 | 60 | self.logger = WandbLogger(name=f'{self.experiment_name}_WandLogger', project=self.experiment_name) 61 | else: 62 | raise NotImplementedError 63 | 64 | @print_done(desc="Creating directories and fix random seeds") 65 | def _init_args(self, args): 66 | self.output_dir.mkdir(parents=True, exist_ok=True) 67 | self.experiment_output_dir.mkdir(parents=True, exist_ok=True) 68 | self.log_dir.mkdir(parents=True, exist_ok=True) 69 | self.save_dir.mkdir(parents=True, exist_ok=True) 70 | self.cache_dir.mkdir(parents=True, exist_ok=True) 71 | 72 | random.seed(args.seed) 73 | np.random.seed(args.seed) 74 | torch.manual_seed(args.seed) 75 | torch.cuda.manual_seed_all(args.seed) 76 | -------------------------------------------------------------------------------- /src/models/lightning_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | - from_argparse_args: 5 | https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.utilities.argparse.html#p 6 | ytorch_lightning.utilities.argparse.from_argparse_args 7 | - ModelCheckpoint 8 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/ 9 | pytorch_lightning.callbacks.ModelCheckpoint.html?highlight=ModelCheckpoint 10 | - Trainer.from_argparse_args(args) 11 | https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html 12 | - Optimizers: AdamW and AdaFactor 13 | https://huggingface.co/docs/transformers/main_classes/optimizer_schedules 14 | Adaptive Learning Rates with Sublinear Memory Cost https://arxiv.org/abs/1804.04235 15 | - optimizer_grouped_parameters 16 | https://huggingface.co/transformers/v3.3.1/training.html 17 | The optimizer allows us to apply different hyperpameters for specific parameter groups. 
18 | - rank_zero_only 19 | http://www.liuxiao.org/2020/07/pytorch-lighting-%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98%E6%95%B4%E7%90%86/ 20 | 使用 @rank_zero_only 修饰多线程中只在 RANK=0 调用的函数 21 | - save vocab.json 22 | https://huggingface.co/transformers/v1.0.0/model_doc/overview.html 23 | @Notes: 24 | - huggingface from_pretrained default store path 25 | windows: ~/.cache/huggingface 26 | - architecture 27 | model.prepare_data() 28 | initialize_distributed() 29 | model.setup(stage) 30 | model.train_dataloader() 31 | model.val_dataloader() 32 | model.test_dataloader() 33 | """ 34 | 35 | import argparse 36 | import logging 37 | import torch 38 | import pytorch_lightning as pl 39 | 40 | from pathlib import Path 41 | from typing import Any, Dict, Optional 42 | from pytorch_lightning.utilities import rank_zero_info 43 | 44 | from transformers import ( 45 | AutoConfig, 46 | AutoTokenizer, 47 | PretrainedConfig, 48 | PreTrainedTokenizer, 49 | ) 50 | 51 | from transformers.optimization import ( 52 | Adafactor, 53 | ) 54 | 55 | from src.configuration.constants import MODEL_CLASSES, GET_SCHEDULER_FUNCS 56 | from src.utils.string_utils import are_same_strings 57 | 58 | logger = logging.getLogger(__name__) 59 | 60 | 61 | class BaseTransformer(pl.LightningModule): 62 | loss_names = {"loss", } 63 | metric_names = {"loss", } 64 | 65 | def __init__( 66 | self, 67 | hparams: argparse.Namespace, 68 | model_type="base", 69 | **kwargs 70 | ): 71 | """Initialize a model, tokenizer and config.""" 72 | super().__init__() 73 | self.save_hyperparameters(hparams) # save hparams to self.hparams 74 | self.output_dir = Path(self.hparams.output_dir) 75 | self.experiment_name = self.hparams.experiment_name 76 | self.model_name = self.hparams.model_name 77 | self.experiment_output_dir = self.output_dir.joinpath(self.hparams.experiment_name) 78 | self.pretrained_save_path = Path(self.experiment_output_dir).joinpath("best_tfmr") 79 | self.model_type = model_type 80 | self.batch_size = None # for auto_scale_batch_size 81 | # record api 82 | self.config = None 83 | self.tokenizer = None 84 | self.model = None 85 | self.optimizer = None 86 | self.scheduler = None 87 | 88 | def _set_up(self, 89 | config: PretrainedConfig = None, 90 | tokenizer: PreTrainedTokenizer = None, 91 | model=None, **config_kwargs): 92 | # load pretrained settings 93 | # config 94 | self.config: PretrainedConfig = config if config is not None else \ 95 | AutoConfig.from_pretrained(self.hparams.model_name_or_path, 96 | **config_kwargs) 97 | self._check_config(self.config) 98 | # tokenizer 99 | self.tokenizer: PreTrainedTokenizer = tokenizer if tokenizer is not None else \ 100 | AutoTokenizer.from_pretrained(self.hparams.model_name_or_path) 101 | # model 102 | self.model_class = MODEL_CLASSES[self.model_type] 103 | self.model = model if model is not None \ 104 | else self._load_model(self.hparams.model_name_or_path, self.model_class, config) 105 | 106 | @property 107 | def vocab_size(self): 108 | return len(self.tokenizer) 109 | 110 | def _check_config(self, config: PretrainedConfig): 111 | extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") 112 | for p in extra_model_params: 113 | if getattr(self.hparams, p, None): 114 | assert hasattr(config, p), f"model config doesn't have a `{p}` attribute" 115 | setattr(config, p, getattr(self.hparams, p)) 116 | 117 | def _load_model(self, model_name_or_path: str, model_class, config: PretrainedConfig = None, cache_dir=None): 118 | if config is None: 119 | return 
model_class.from_pretrained( 120 | model_name_or_path, 121 | cache_dir=cache_dir, 122 | trust_remote_code=True 123 | ) 124 | else: 125 | return model_class.from_pretrained( 126 | model_name_or_path, 127 | cache_dir=cache_dir, 128 | config=config, 129 | trust_remote_code=True 130 | ) 131 | 132 | def get_lr_scheduler(self, optimizer: torch.optim.Optimizer, frequency=1): 133 | scheduler = None 134 | if self.hparams.lr_scheduler in GET_SCHEDULER_FUNCS: 135 | get_schedule_func = GET_SCHEDULER_FUNCS[self.hparams.lr_scheduler] 136 | scheduler = get_schedule_func(optimizer, num_warmup_steps=self.hparams.warmup_steps, 137 | num_training_steps=self.total_steps()) 138 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": frequency} 139 | elif are_same_strings(self.hparams.lr_scheduler, "ReduceLROnPlateau"): 140 | scheduler = {"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), 141 | "interval": "epoch", "frequency": frequency, "monitor": "val_loss"} 142 | else: 143 | raise NotImplementedError() 144 | return scheduler 145 | 146 | def configure_optimizers(self): 147 | """Prepare optimizer and schedule (linear warmup and decay)""" 148 | model = self.model 149 | no_decay = ["bias", "LayerNorm.weight"] 150 | optimizer_grouped_parameters = [ 151 | { 152 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 153 | "weight_decay": self.hparams.weight_decay, 154 | }, 155 | { 156 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 157 | "weight_decay": 0.0, 158 | }, 159 | ] 160 | if are_same_strings(self.hparams.optimizer_class, "Adafactor"): 161 | optimizer = Adafactor( 162 | optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False 163 | ) 164 | 165 | elif are_same_strings(self.hparams.optimizer_class, "AdamW"): 166 | optimizer = torch.optim.AdamW( 167 | optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon 168 | ) 169 | else: 170 | raise NotImplementedError(f"{self.hparams.optimizer_class} not available. Only Adafactor|Adafactor") 171 | 172 | self.optimizer = optimizer 173 | 174 | frequency = 1 175 | self.scheduler = self.get_lr_scheduler(self.optimizer, frequency=frequency) 176 | 177 | return { 178 | "optimizer": optimizer, 179 | "lr_scheduler": self.scheduler 180 | } 181 | 182 | def test_step(self, batch, batch_nb): 183 | return self.validation_step(batch, batch_nb) 184 | 185 | def test_epoch_end(self, outputs): 186 | return self.validation_end(outputs) 187 | 188 | def total_steps(self) -> int: 189 | """The number of total training steps that will be run. 
Used for lr scheduler purposes.""" 190 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores 191 | if isinstance(self.hparams.accumulate_grad_batches, dict): 192 | accumulate_grad_batches = list(self.hparams.accumulate_grad_batches.values())[-1] 193 | else: 194 | accumulate_grad_batches = self.hparams.accumulate_grad_batches 195 | effective_batch_size = self.hparams.train_batch_size * accumulate_grad_batches * num_devices 196 | return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs 197 | 198 | def setup(self, stage: Optional[str] = None): 199 | # stage: train, test, eval 200 | self.dataset_size = len(self.train_dataloader().dataset) 201 | 202 | def get_dataloader(self, data_type: str, batch_size: int, shuffle: bool = False): 203 | raise NotImplementedError("You must implement this for your task") 204 | 205 | def train_dataloader(self): 206 | return self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) 207 | 208 | def val_dataloader(self): 209 | return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size, shuffle=False) 210 | 211 | def test_dataloader(self): 212 | return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size, shuffle=False) 213 | 214 | @pl.utilities.rank_zero_only 215 | def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 216 | save_path = self.pretrained_save_path 217 | self.model.config.save_step = self.step_count 218 | self.model.save_pretrained(save_path) 219 | self.tokenizer.save_pretrained(save_path) 220 | print(f"hugging face format checkpoint save at {save_path}") 221 | 222 | def update_loss_names(self, loss_result: Dict, update_flag=True): 223 | if not update_flag: 224 | return 225 | for loss_name_ in loss_result.keys(): 226 | if loss_name_ not in self.loss_names: 227 | self.loss_names.add(loss_name_) 228 | 229 | def update_metric_names(self, metrics: Dict, update_flag=True): 230 | if not update_flag: 231 | return 232 | for metric_name_ in metrics.keys(): 233 | if metric_name_ not in self.metric_names: 234 | self.metric_names.add(metric_name_) 235 | -------------------------------------------------------------------------------- /src/models/task/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt_code import ( 2 | gpt_code, 3 | ) 4 | from .llama_finetune import llama_finetune 5 | -------------------------------------------------------------------------------- /src/models/task/gpt_code.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | - BART shift_tokens_right 5 | https://huggingface.co/docs/transformers/v4.17.0/en/model_doc/bart#bart 6 | - Label Smoothing 7 | https://paperswithcode.com/method/label-smoothing 8 | - bart models from huggingface 9 | e.g. https://huggingface.co/facebook/bart-base 10 | @Notes: 11 | - BART shift_tokens_right 12 | Bart uses the eos_token_id as the starting token for decoder_input_ids generation. 13 | If past_key_values is used, optionally only the last decoder_input_ids have to be input (see past_key_values). 14 | For translation and summarization training, decoder_input_ids should be provided. If no decoder_input_ids is provided, 15 | the model will create this tensor by shifting the input_ids to the right for denoising pre-training following the paper. 
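        A minimal sketch of the shift (illustrative only; in recent transformers versions the helper is
        transformers.models.bart.modeling_bart.shift_tokens_right):
            labels:             [bos, t1, t2, t3, eos]
            decoder_input_ids:  [eos, bos, t1, t2, t3]   # decoder_start_token_id (eos for BART) prepended, last token dropped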
16 | - label-smoothing 17 | During finetuning we use a label smoothed cross entropy loss (Pereyra et al., 2017), with the smoothing parameter 18 | set to 0.1. 19 | - model generate: 20 | in generation_utils.py e.g.BartForConditionalGeneration().generate -> def generate in generation_utils.py 21 | - torch.nn.CrossEntropyLoss 22 | Shape: 23 | - Input: :math:`(N, C)` where `C = number of classes`, or 24 | :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` 25 | in the case of `K`-dimensional loss. 26 | - Target: If containing class indices, shape :math:`(N)` where each value is 27 | :math:`0 \leq \text{targets}[i] \leq C-1`, or :math:`(N, d_1, d_2, ..., d_K)` with 28 | :math:`K \geq 1` in the case of K-dimensional loss. If containing class probabilities, 29 | same shape as the input. 30 | """ 31 | 32 | import logging 33 | import numpy as np 34 | import torch 35 | 36 | from datetime import datetime 37 | from collections import defaultdict 38 | from pathlib import Path 39 | from typing import Dict 40 | from torch.utils.data import DataLoader 41 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer 42 | from src.configuration.constants import DATASET_CLASSES 43 | from src.utils.task import model_utils 44 | from src.models.lightning_base import BaseTransformer 45 | 46 | logger = logging.getLogger(__name__) 47 | 48 | 49 | class gpt_code(BaseTransformer): 50 | def __init__(self, hparams, **kwargs): 51 | super().__init__(hparams, **kwargs) 52 | self._custom_init() 53 | 54 | # Whether changing embeddings 55 | if self.hparams.freeze_embeds: 56 | model_utils.freeze_embeds(self.model) 57 | if self.hparams.freeze_encoder: 58 | model_utils.freeze_params(self.model.get_encoder()) 59 | model_utils.assert_all_frozen(self.model.get_encoder()) 60 | 61 | self.step_count = 0 62 | self.current_val_metrics = {} 63 | self.metrics_save_path = Path(self.experiment_output_dir) / "metrics.json" 64 | self.metrics: dict = defaultdict(list) 65 | self.model_type = self.config.model_type 66 | self.decoder_start_token_id = self.model.config.decoder_start_token_id # default to config 67 | self.already_saved_batch = False # flag of saving readable batch 68 | self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams 69 | self.val_metric = "loss" if self.hparams.val_metric is None else self.hparams.val_metric 70 | self.save_readable_batch = True # for debug 71 | self.metric_names_update_flag = True 72 | 73 | # predicted 74 | self.use_top_p = False 75 | self.top_p = 0.9 76 | self.store_test_output = True 77 | self.test_output = None 78 | self.remain_sp_tokens = self.hparams.remain_sp_tokens 79 | if self.remain_sp_tokens: 80 | print("remain special tokens in target and pred text (e.g. 
[EVENT_s])") 81 | 82 | @property 83 | def pad_token_id(self) -> int: 84 | return self.tokenizer.pad_token_id 85 | 86 | def _custom_init(self): 87 | # config 88 | self.config = AutoConfig.from_pretrained(self.hparams.model_name_or_path, trust_remote_code=True) 89 | # tokenizer 90 | self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.model_name_or_path, trust_remote_code=True) 91 | # model 92 | self.model = self._load_model(self.hparams.model_name_or_path, AutoModelForCausalLM, self.config) 93 | 94 | self._set_up(config=self.config, tokenizer=self.tokenizer, model=self.model) 95 | self.train_dataset_class = DATASET_CLASSES[self.hparams.training_stage] 96 | self.test_dataset_class = DATASET_CLASSES[self.hparams.training_stage] 97 | 98 | def forward(self, input_ids, **kwargs): 99 | return self.model(input_ids, **kwargs) 100 | 101 | def _step(self, batch: dict): 102 | outputs = self.model(**batch) 103 | return outputs 104 | 105 | def training_step(self, batch, batch_idx) -> Dict: 106 | loss = self._step(batch)['loss'] 107 | if torch.isnan(loss): 108 | print(batch) 109 | self.log("train_loss", loss, prog_bar=True, sync_dist=True) # metrics logged can be access by trainer.callback_metrics 110 | logs = {"loss": loss.item(), "batch_size": batch["input_ids"].shape[0]} 111 | return {"loss": loss, "log": logs} 112 | 113 | def training_epoch_end(self, outputs): 114 | pass 115 | 116 | @torch.no_grad() 117 | def _generative_step(self, batch: dict) -> dict: 118 | tik = datetime.now() 119 | 120 | outputs = self._step(batch) 121 | loss = outputs['loss'] 122 | 123 | tok = datetime.now() 124 | batch_gen_time = tok - tik 125 | base_metrics = {"loss": loss.item()} 126 | 127 | # update metric_names 128 | self.update_metric_names(base_metrics, update_flag=self.metric_names_update_flag) 129 | self.metric_names_update_flag = False 130 | base_metrics.update(batch_gen_time=batch_gen_time) 131 | 132 | return base_metrics 133 | 134 | def validation_step(self, batch, batch_idx) -> Dict: 135 | return self._generative_step(batch) 136 | 137 | def validation_epoch_end(self, outputs, prefix="val") -> Dict: 138 | self.step_count += 1 139 | generative_metrics = { 140 | name: np.array([x[name] for x in outputs]).mean() for name in self.metric_names 141 | } 142 | metric_val = ( 143 | torch.tensor(generative_metrics[self.val_metric]) 144 | ) 145 | val_metrics = {f"{prefix}_{k}": x for k, x in generative_metrics.items()} 146 | val_metrics["step_count"] = float(self.step_count) 147 | self.current_val_metrics = val_metrics 148 | self.log_dict(self.current_val_metrics, sync_dist=True) 149 | self.metrics[prefix].append(val_metrics) # callback writes this to self.metrics_save_path. 
150 | print(f"Evaluation result: {val_metrics}") 151 | return { 152 | "log": val_metrics, 153 | f"{prefix}_loss": generative_metrics["loss"], 154 | f"{prefix}_{self.val_metric}": metric_val, 155 | } 156 | 157 | def test_step(self, batch, batch_idx): 158 | tik = datetime.now() 159 | outputs = self._step(batch) 160 | loss = outputs['loss'] 161 | logits = outputs['logits'] 162 | 163 | # Shift so that tokens < n predict n 164 | shift_logits = logits[..., :-1, :].contiguous() 165 | shift_labels = batch['input_ids'][..., 1:].contiguous() 166 | 167 | shift_logits_argmax = torch.argmax(shift_logits, dim=-1) 168 | logits_acc = (shift_logits_argmax == shift_labels).byte() 169 | 170 | logits_ppl = [] 171 | for log, lab in zip(shift_logits, shift_labels): 172 | logits_ppl.append(-torch.nn.functional.log_softmax(log, dim=-1).gather(1, lab.unsqueeze(-1)).squeeze(1)) 173 | 174 | logits_ppl = torch.stack(logits_ppl) 175 | 176 | tok = datetime.now() 177 | batch_gen_time = tok - tik 178 | base_metrics = {"loss": loss.item(), "logits_acc": logits_acc.cpu(), "logits_ppl": logits_ppl.cpu(), 179 | "shift_logits_argmax": shift_logits_argmax.cpu(), "shift_labels": shift_labels.cpu()} 180 | 181 | # update metric_names 182 | self.update_metric_names(base_metrics, update_flag=self.metric_names_update_flag) 183 | self.metric_names_update_flag = False 184 | base_metrics.update(batch_gen_time=batch_gen_time) 185 | 186 | return base_metrics 187 | 188 | def test_epoch_end(self, outputs): 189 | prefix = 'test' 190 | self.step_count += 1 191 | loss_metrics = {} 192 | 193 | for name in self.metric_names: 194 | if 'loss' in name: 195 | loss_metrics[name] = np.array([x[name] for x in outputs]).mean() 196 | else: 197 | loss_metrics[name] = [x[name].numpy().tolist() for x in outputs] 198 | 199 | metric_val = (torch.tensor(loss_metrics[self.val_metric])) 200 | 201 | val_metrics = {f"{prefix}_{k}": x for k, x in loss_metrics.items()} 202 | val_metrics["step_count"] = float(self.step_count) 203 | self.current_val_metrics = val_metrics 204 | self.metrics[prefix].append(val_metrics) # callback writes this to self.metrics_save_path. 
205 | test_output = { 206 | "log": val_metrics, 207 | f"{prefix}_loss": loss_metrics["loss"], 208 | f"{prefix}_{self.val_metric}": metric_val, 209 | } 210 | if self.store_test_output: 211 | self.test_output = test_output 212 | return test_output 213 | 214 | def get_dataset(self, data_type): 215 | if data_type == 'train': 216 | dataset = self.train_dataset_class(self.hparams, self.tokenizer, data_type) 217 | else: 218 | dataset = self.test_dataset_class(self.hparams, self.tokenizer, data_type) 219 | # self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer)) 220 | return dataset 221 | 222 | def get_dataloader(self, data_type: str, batch_size: int, shuffle: bool = False) -> DataLoader: 223 | dataset = self.get_dataset(data_type) 224 | return DataLoader( 225 | dataset, 226 | batch_size=batch_size, 227 | collate_fn=dataset.collate_fn, 228 | shuffle=shuffle, 229 | num_workers=self.hparams.num_workers, 230 | pin_memory=True 231 | ) 232 | 233 | def train_dataloader(self) -> DataLoader: 234 | train_shuffle = True if self.hparams.overfit_batches == 0.0 else False 235 | if not train_shuffle: 236 | print(f"train_shuffle: {train_shuffle} overfit_batches: {self.hparams.overfit_batches}") 237 | return self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=train_shuffle) 238 | 239 | def val_dataloader(self) -> DataLoader: 240 | return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size, shuffle=False) 241 | 242 | def test_dataloader(self) -> DataLoader: 243 | return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size, shuffle=False) 244 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/modules/__init__.py -------------------------------------------------------------------------------- /src/modules/pl_callbacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes:. 5 | - pytorch_lightning 6 | https://pytorch-lightning.readthedocs.io/en/latest/starter/new-project.html 7 | - ModelCheckpoint 8 | If we want set ModelCheckpoint with save_top_k, we need set callback_metrics for trainer. 9 | - rank_zero_only 10 | Whether the value will be logged only on rank 0. This will prevent synchronization which 11 | would produce a deadlock as not all processes would perform this log call. 
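      A minimal usage sketch (illustrative only):
          from pytorch_lightning.utilities import rank_zero_only

          @rank_zero_only
          def write_once(path, text):
              # runs only on the global-rank-0 process, so the file is written a single time
              with open(path, "w") as f:
                  f.write(text)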
12 | """ 13 | import os 14 | import pytorch_lightning as pl 15 | import torch 16 | 17 | from pathlib import Path 18 | from pytorch_lightning.utilities import rank_zero_only, rank_zero_info 19 | from src.utils.file_utils import save_json 20 | 21 | 22 | # ================================== call_back classes ================================== 23 | class Seq2SeqLoggingCallback(pl.Callback): 24 | def on_batch_end(self, trainer, pl_module): 25 | lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} 26 | pl_module.logger.log_metrics(lrs) 27 | # pl_module.logger.log_metrics(trainer.callback_metrics, step=trainer.global_step) # TODO: uncomment it 28 | 29 | @rank_zero_only 30 | def _write_logs( 31 | self, trainer: pl.Trainer, pl_module: pl.LightningModule, file_prefix: str, save_generations=True 32 | ) -> None: 33 | print(f"***** {file_prefix} results at step {trainer.global_step:05d} *****") 34 | metrics = trainer.callback_metrics 35 | trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) 36 | # Log results 37 | output_dir = Path(pl_module.experiment_output_dir) 38 | if file_prefix == "test": 39 | results_file = output_dir / "test_results.txt" 40 | generations_file = output_dir / "test_generations.txt" 41 | else: 42 | results_file = output_dir / f"{file_prefix}_results/{trainer.global_step:05d}.txt" 43 | generations_file = output_dir / f"{file_prefix}_generations/{trainer.global_step:05d}.txt" 44 | 45 | results_file.parent.mkdir(exist_ok=True) 46 | generations_file.parent.mkdir(exist_ok=True) 47 | 48 | with open(results_file, "a+") as writer: 49 | for key in sorted(metrics): 50 | if key in ["log", "progress_bar", "preds"]: 51 | continue 52 | val = metrics[key] 53 | if isinstance(val, torch.Tensor): 54 | val = val.item() 55 | msg = f"{key}: {val:.6f}\n" 56 | writer.write(msg) 57 | 58 | if not save_generations: 59 | return 60 | 61 | if "preds" in metrics: 62 | content = "\n".join(metrics["preds"]) 63 | generations_file.open("w+").write(content) 64 | 65 | @rank_zero_only 66 | def on_train_start(self, trainer, pl_module): 67 | try: 68 | nparams = pl_module.model.model.num_parameters() 69 | except AttributeError: 70 | nparams = pl_module.model.num_parameters() 71 | 72 | # mp stands for million parameters 73 | params_stat = {"n_params": nparams, "mp": nparams / 1e6} 74 | trainer.logger.log_metrics(params_stat) 75 | print(f"Training is started! 
params statistics: {params_stat}") 76 | 77 | @rank_zero_only 78 | def on_train_end(self, trainer, pl_module): 79 | print("Training is done.") 80 | 81 | @rank_zero_only 82 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 83 | save_json(pl_module.metrics, pl_module.metrics_save_path) 84 | return self._write_logs(trainer, pl_module, "test") 85 | 86 | @rank_zero_only 87 | def on_validation_end(self, trainer: pl.Trainer, pl_module): 88 | save_json(pl_module.metrics, pl_module.metrics_save_path) 89 | 90 | 91 | class LoggingCallback(pl.Callback): 92 | def on_train_batch_end(self, trainer, pl_module): 93 | lr_scheduler = trainer.lr_schedulers[0]["scheduler"] 94 | lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())} 95 | pl_module.logger.log_metrics(lrs) 96 | 97 | def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 98 | rank_zero_info("***** Validation results *****") 99 | metrics = trainer.callback_metrics 100 | pl_module.logger.log_metrics(metrics) 101 | # Log results 102 | for key in sorted(metrics): 103 | if key not in ["log", "progress_bar"]: 104 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 105 | 106 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 107 | rank_zero_info("***** Test results *****") 108 | metrics = trainer.callback_metrics 109 | # Log and save results to file 110 | output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") 111 | with open(output_test_results_file, "w") as writer: 112 | for key in sorted(metrics): 113 | if key not in ["log", "progress_bar"]: 114 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 115 | writer.write("{} = {}\n".format(key, str(metrics[key]))) 116 | 117 | 118 | class Seq2SeqCheckpointCallback(pl.callbacks.ModelCheckpoint): 119 | available_metrics = ["val_rouge2", "val_bleu", "val_loss"] 120 | 121 | def __init__(self, output_dir, experiment_name, monitor="val_loss", 122 | save_top_k=1, every_n_train_steps=1000, verbose=False, **kwargs): 123 | self.output_dir = output_dir 124 | self.experiment_name = experiment_name 125 | self.monitor = monitor 126 | self.save_top_k = save_top_k 127 | self.every_n_train_steps = every_n_train_steps 128 | self.verbose = verbose 129 | self.check_monitor_validity(self.monitor) 130 | super(Seq2SeqCheckpointCallback, self).__init__(dirpath=self.output_dir, 131 | filename=f"{self.experiment_name}" + 132 | '-{epoch:02d}-{step}-{' + 133 | f"{self.monitor}" + ':.4f}', 134 | auto_insert_metric_name=True, 135 | every_n_train_steps=self.every_n_train_steps, 136 | verbose=self.verbose, 137 | monitor=self.monitor, 138 | mode="min", 139 | save_top_k=self.save_top_k, 140 | **kwargs) 141 | 142 | def check_monitor_validity(self, monitor): 143 | """Saves the best model by validation ROUGE2 score.""" 144 | if monitor in self.available_metrics: 145 | pass 146 | else: 147 | raise NotImplementedError( 148 | f"seq2seq callbacks only support {self.available_metrics}, got {monitor}, " 149 | f"You can make your own by adding to this function." 150 | ) 151 | 152 | 153 | class SaveCheckpointEveryEpoch(pl.callbacks.ModelCheckpoint): 154 | """save checkpoint after each training epoch without validation. 155 | if ``last_k == -1``, all models are saved. and no monitor needed in this condition. 156 | otherwise, please log ``global_step`` in the training_step. e.g. self.log('global_step', self.global_step) 157 | 158 | :param last_k: the latest k models will be saved. 
159 | :param save_weights_only: if ``True``, only the model's weights will be saved, 160 | else the full model is saved. 161 | """ 162 | 163 | def __init__(self, last_k, save_weights_only, output_dir, experiment_name, **kwargs): 164 | if last_k == -1: 165 | super().__init__(save_top_k=-1, save_last=False, save_weights_only=save_weights_only, dirpath=output_dir, 166 | filename=f"{experiment_name}" + '-{epoch:02d}-{step}', **kwargs) # TODO add {val_loss} 167 | else: 168 | super().__init__(monitor='step', mode='max', save_top_k=last_k, 169 | save_last=False, save_weights_only=save_weights_only, dirpath=output_dir, 170 | filename=f"{experiment_name}" + '-{epoch:02d}-{step}', 171 | save_on_train_epoch_end=True, **kwargs) 172 | -------------------------------------------------------------------------------- /src/modules/task/conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conversation prompt template. 3 | 4 | Now we support 5 | - Vicuna 6 | - Koala 7 | - OpenAssistant/oasst-sft-1-pythia-12b 8 | - StabilityAI/stablelm-tuned-alpha-7b 9 | - databricks/dolly-v2-12b 10 | - THUDM/chatglm-6b 11 | - project-baize/baize-lora-7B 12 | - Alpaca/LLaMa 13 | """ 14 | 15 | import dataclasses 16 | from enum import auto, Enum 17 | from typing import List, Any 18 | 19 | 20 | class SeparatorStyle(Enum): 21 | """Different separator style.""" 22 | 23 | SINGLE = auto() 24 | TWO = auto() 25 | DOLLY = auto() 26 | OASST_PYTHIA = auto() 27 | BAIZE = auto() 28 | 29 | 30 | @dataclasses.dataclass 31 | class Conversation: 32 | """A class that keeps all conversation history.""" 33 | 34 | system: str 35 | roles: List[str] 36 | messages: List[List[str]] 37 | offset: int 38 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 39 | sep: str = "###" 40 | sep2: str = None 41 | 42 | # Used for the state in the gradio servers. 
43 | conv_id: Any = None 44 | skip_next: bool = False 45 | model_name: str = None 46 | 47 | def get_prompt(self): 48 | if self.sep_style == SeparatorStyle.SINGLE: 49 | ret = self.system 50 | for role, message in self.messages: 51 | if message: 52 | ret += self.sep + " " + role + ": " + message 53 | else: 54 | ret += self.sep + " " + role + ":" 55 | return ret 56 | elif self.sep_style == SeparatorStyle.TWO: 57 | seps = [self.sep, self.sep2] 58 | ret = self.system + seps[0] 59 | for i, (role, message) in enumerate(self.messages): 60 | if message: 61 | ret += role + ": " + message + seps[i % 2] 62 | else: 63 | ret += role + ":" 64 | return ret 65 | elif self.sep_style == SeparatorStyle.DOLLY: 66 | seps = [self.sep, self.sep2] 67 | ret = self.system 68 | for i, (role, message) in enumerate(self.messages): 69 | if message: 70 | ret += role + ":\n" + message + seps[i % 2] 71 | if i % 2 == 1: 72 | ret += "\n\n" 73 | else: 74 | ret += role + ":\n" 75 | return ret 76 | elif self.sep_style == SeparatorStyle.OASST_PYTHIA: 77 | ret = self.system 78 | for role, message in self.messages: 79 | if message: 80 | ret += role + message + self.sep 81 | else: 82 | ret += role 83 | return ret 84 | elif self.sep_style == SeparatorStyle.BAIZE: 85 | ret = self.system 86 | for role, message in self.messages: 87 | if message: 88 | ret += "\n" + role + message 89 | else: 90 | ret += "\n" + role 91 | return ret 92 | else: 93 | raise ValueError(f"Invalid style: {self.sep_style}") 94 | 95 | def append_message(self, role, message): 96 | self.messages.append([role, message]) 97 | 98 | def to_gradio_chatbot(self): 99 | ret = [] 100 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 101 | if i % 2 == 0: 102 | ret.append([msg, None]) 103 | else: 104 | ret[-1][-1] = msg 105 | return ret 106 | 107 | def copy(self): 108 | return Conversation( 109 | system=self.system, 110 | roles=self.roles, 111 | messages=[[x, y] for x, y in self.messages], 112 | offset=self.offset, 113 | sep_style=self.sep_style, 114 | sep=self.sep, 115 | sep2=self.sep2, 116 | conv_id=self.conv_id, 117 | model_name=self.model_name, 118 | ) 119 | 120 | def dict(self): 121 | return { 122 | "system": self.system, 123 | "roles": self.roles, 124 | "messages": self.messages, 125 | "offset": self.offset, 126 | "sep": self.sep, 127 | "sep2": self.sep2, 128 | "conv_id": self.conv_id, 129 | "model_name": self.model_name, 130 | } 131 | 132 | 133 | conv_one_shot = Conversation( 134 | system="A chat between a curious human and an artificial intelligence assistant. " 135 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 136 | roles=("Human", "Assistant"), 137 | messages=( 138 | ( 139 | "Human", 140 | "What are the key differences between renewable and non-renewable energy sources?", 141 | ), 142 | ( 143 | "Assistant", 144 | "Renewable energy sources are those that can be replenished naturally in a relatively " 145 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 146 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 147 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 148 | "renewable and non-renewable energy sources:\n" 149 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 150 | "energy sources are finite and will eventually run out.\n" 151 | "2. 
Environmental impact: Renewable energy sources have a much lower environmental impact " 152 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 153 | "and other negative effects.\n" 154 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 155 | "have lower operational costs than non-renewable sources.\n" 156 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 157 | "locations than non-renewable sources.\n" 158 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 159 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 160 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 161 | "non-renewable sources are not, and their depletion can lead to economic and social instability.", 162 | ), 163 | ), 164 | offset=2, 165 | sep_style=SeparatorStyle.SINGLE, 166 | sep="###", 167 | ) 168 | 169 | conv_vicuna_v1_1 = Conversation( 170 | system="A chat between a curious user and an artificial intelligence assistant. " 171 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 172 | roles=("USER", "ASSISTANT"), 173 | messages=(), 174 | offset=0, 175 | sep_style=SeparatorStyle.TWO, 176 | sep=" ", 177 | sep2="", 178 | ) 179 | 180 | conv_koala_v1 = Conversation( 181 | system="BEGINNING OF CONVERSATION:", 182 | roles=("USER", "GPT"), 183 | messages=(), 184 | offset=0, 185 | sep_style=SeparatorStyle.TWO, 186 | sep=" ", 187 | sep2="", 188 | ) 189 | 190 | conv_dolly = Conversation( 191 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", 192 | roles=("### Instruction", "### Response"), 193 | messages=(), 194 | offset=0, 195 | sep_style=SeparatorStyle.DOLLY, 196 | sep="\n\n", 197 | sep2="### End", 198 | ) 199 | 200 | conv_oasst = Conversation( 201 | system="", 202 | roles=("<|prompter|>", "<|assistant|>"), 203 | messages=(), 204 | offset=0, 205 | sep_style=SeparatorStyle.OASST_PYTHIA, 206 | sep="<|endoftext|>", 207 | ) 208 | 209 | conv_stablelm = Conversation( 210 | system="""<|SYSTEM|># StableLM Tuned (Alpha version) 211 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. 212 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 213 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 214 | - StableLM will refuse to participate in anything that could harm a human. 215 | """, 216 | roles=("<|USER|>", "<|ASSISTANT|>"), 217 | messages=(), 218 | offset=0, 219 | sep_style=SeparatorStyle.OASST_PYTHIA, 220 | sep="", 221 | ) 222 | 223 | conv_baize = Conversation( 224 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. 
Complete the transcript in exactly that format.",
225 |     roles=("[|Human|]", "[|AI|]"),
226 |     messages=(
227 |         ("[|Human|]", "Hello!"),
228 |         ("[|AI|]", "Hi!"),
229 |     ),
230 |     offset=2,
231 |     sep_style=SeparatorStyle.BAIZE,
232 |     sep="[|Human|]",
233 | )
234 | 
235 | conv_templates = {
236 |     "conv_one_shot": conv_one_shot,
237 |     "vicuna_v1.1": conv_vicuna_v1_1,
238 |     "koala_v1": conv_koala_v1,
239 |     "dolly": conv_dolly,
240 |     "oasst": conv_oasst,
241 |     "baize": conv_baize,
242 | }
243 | 
244 | 
245 | def get_default_conv_template(model_name):
246 |     model_name = model_name.lower()
247 |     if "vicuna" in model_name or "output" in model_name:
248 |         return conv_vicuna_v1_1
249 |     elif "koala" in model_name:
250 |         return conv_koala_v1
251 |     elif "dolly-v2" in model_name:
252 |         return conv_dolly
253 |     elif "oasst" in model_name and "pythia" in model_name:
254 |         return conv_oasst
255 |     elif "baize" in model_name:
256 |         return conv_baize
257 |     elif "stablelm" in model_name:
258 |         return conv_stablelm
259 |     return conv_one_shot
260 | 
261 | 
262 | def compute_skip_echo_len(model_name, conv, prompt):
263 |     model_name = model_name.lower()
264 |     if "chatglm" in model_name:
265 |         skip_echo_len = len(conv.messages[-2][1]) + 1
266 |     elif "dolly-v2" in model_name:
267 |         special_toks = ["### Instruction:", "### Response:", "### End"]
268 |         skip_echo_len = len(prompt)
269 |         for tok in special_toks:
270 |             skip_echo_len -= prompt.count(tok) * len(tok)
271 |     elif "oasst" in model_name and "pythia" in model_name:
272 |         special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
273 |         skip_echo_len = len(prompt)
274 |         for tok in special_toks:
275 |             skip_echo_len -= prompt.count(tok) * len(tok)
276 |     elif "stablelm" in model_name:
277 |         special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
278 |         skip_echo_len = len(prompt)
279 |         for tok in special_toks:
280 |             skip_echo_len -= prompt.count(tok) * len(tok)
281 |     elif "baize" in model_name:
282 |         skip_echo_len = len(prompt)
283 |     else:
284 |         skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
285 |     return skip_echo_len
286 | 
287 | 
288 | if __name__ == "__main__":
289 |     default_conversation = conv_templates["vicuna_v1.1"]
290 |     print(default_conversation.get_prompt())
291 | 
-------------------------------------------------------------------------------- /src/modules/task/data_loader.py: --------------------------------------------------------------------------------
1 | """
2 | @Desc:
3 | @Reference:
4 |     - transformers examples for using BART model
5 |         https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization
6 |         https://discuss.huggingface.co/t/train-bart-for-conditional-generation-e-g-summarization/1904/2
7 |     - add_special_tokens
8 |         https://huggingface.co/docs/transformers/v4.17.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase
9 |     - linecache
10 |         https://blog.csdn.net/my2010Sam/article/details/38022041
11 |     - torch Dataset
12 |         https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset
13 | @Notes:
14 |     - add_special_tokens
15 |         special_tokens_dict (dictionary str to str or tokenizers.AddedToken) —
16 |         Keys should be in the list of predefined special attributes:
17 |         [bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens].
18 |         Tokens are only added if they are not already in the vocabulary (tested by checking
19 |         if the tokenizer assigns the index of the unk_token to them). 
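        A minimal sketch (illustrative; the "[EVENT_s]"/"[EVENT_e]" strings are example tokens, not necessarily this repo's setup):
            special_tokens_dict = {"pad_token": "[PAD]", "additional_special_tokens": ["[EVENT_s]", "[EVENT_e]"]}
            num_added = tokenizer.add_special_tokens(special_tokens_dict)
            if num_added > 0:
                model.resize_token_embeddings(len(tokenizer))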
20 | - collate_fn 21 | A custom collate_fn can be used to customize collation, e.g., padding sequential data to max length of a batch. 22 | See this section on more about collate_fn. 23 | """ 24 | 25 | import os 26 | import copy 27 | import transformers 28 | import torch 29 | 30 | from tqdm import tqdm 31 | from dataclasses import dataclass 32 | from typing import Dict, Sequence 33 | from torch.utils.data import Dataset 34 | from datasets import load_dataset 35 | from src.utils.print_utils import print_rank_0 36 | 37 | IGNORE_INDEX = -100 38 | 39 | 40 | class PretrainingDataset(Dataset): 41 | def __init__(self, config, tokenizer, data_type): 42 | self.config = config 43 | self.tokenizer = tokenizer 44 | self.data = load_dataset('json', data_files=os.path.join(config.data_dir, f'{data_type}.json'))['train'] 45 | self.print_dataset_example(example=self[0]) 46 | 47 | def collate_fn(self, batch): 48 | input_ids = torch.tensor([sample["input_ids"] for sample in batch]) 49 | return {"input_ids": input_ids, "labels": input_ids} 50 | 51 | def print_dataset_example(self, example): 52 | print("input_ids", example["input_ids"]) 53 | print("inputs", self.tokenizer.decode(example["input_ids"])) 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __getitem__(self, index): 59 | input_ids = self.data[index]['token'] 60 | return {'input_ids': input_ids} 61 | 62 | 63 | class SupervisedDataset(Dataset): 64 | def __init__(self, config, tokenizer, data_type): 65 | self.config = config 66 | self.tokenizer = tokenizer 67 | self.data = load_dataset('json', data_files=os.path.join(config.data_dir, f'{data_type}.json'))['train'] 68 | self.data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 69 | print_rank_0(self.data[0]) 70 | 71 | def _add_speaker_and_signal(self, conversation): 72 | meta_instruction = ("一位用户和法律大模型韩非之间的对话。" 73 | "对于用户的法律咨询,韩非给出准确的、详细的、温暖的指导建议。" 74 | "对于用户的指令问题,韩非给出有益的、详细的、有礼貌的回答。\n\n") 75 | conversation_roles = ("用户", "韩非") 76 | 77 | def tokenize(prompt): 78 | result = self.tokenizer(prompt, max_length=self.config.max_source_length, 79 | truncation=True, padding=False) 80 | return {"input_ids": result["input_ids"], "labels": copy.deepcopy(result["input_ids"])} 81 | 82 | user_sep, sys_sep = " ", self.tokenizer.eos_token if self.tokenizer.eos_token else "" 83 | input_ids = tokenize(meta_instruction + user_sep)['input_ids'] # NOTE: + user_sep 84 | labels = [IGNORE_INDEX] * len(input_ids) 85 | for turn in conversation: 86 | if turn["from"].lower() == "gpt": 87 | role = conversation_roles[1] 88 | sent = tokenize(role + ":" + turn["value"] + (sys_sep if turn["value"] else "")) 89 | input_ids += sent['input_ids'] 90 | labels += sent['labels'] 91 | else: 92 | role = conversation_roles[0] 93 | sent = tokenize(role + ":" + turn["value"] + (user_sep if turn["value"] else "")) 94 | input_ids += sent['input_ids'] 95 | labels += [IGNORE_INDEX] * len(sent['labels']) 96 | input_ids = torch.tensor(input_ids[:self.config.max_source_length]) 97 | labels = torch.tensor(labels[:self.config.max_source_length]) 98 | return input_ids, labels 99 | 100 | def conversation_data(self, data_point): 101 | conversation = data_point['conversations'] 102 | input_ids, labels = self._add_speaker_and_signal(conversation) 103 | return {'input_ids': input_ids, 'labels': labels} 104 | 105 | def preprocess(self, data_point): 106 | return self.conversation_data(data_point) 107 | 108 | def collate_fn(self, batch): 109 | processed_batch = [self.preprocess(x) for x in batch] 110 | return 
self.data_collator(processed_batch) 111 | 112 | def __len__(self): 113 | return len(self.data) 114 | 115 | def __getitem__(self, index): 116 | return self.data[index] 117 | 118 | 119 | @dataclass 120 | class DataCollatorForSupervisedDataset(object): 121 | """Collate examples for supervised fine-tuning.""" 122 | 123 | tokenizer: transformers.PreTrainedTokenizer 124 | 125 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 126 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 127 | input_ids = torch.nn.utils.rnn.pad_sequence( 128 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 129 | ) 130 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 131 | return dict( 132 | input_ids=input_ids, 133 | labels=labels, 134 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 135 | ) 136 | 137 | -------------------------------------------------------------------------------- /src/modules/task/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import transformers 4 | 5 | from transformers.trainer_pt_utils import LabelSmoother 6 | from torch.utils.data import Dataset 7 | from typing import Dict 8 | from conversation import get_default_conv_template, SeparatorStyle 9 | from src.utils.print_utils import print_rank_0 10 | 11 | 12 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 13 | 14 | 15 | def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer) -> Dict: 16 | conv = get_default_conv_template("vicuna").copy() 17 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 18 | 19 | # Apply prompt templates 20 | conversations = [] 21 | for i, source in enumerate(sources): 22 | if roles[source[0]["from"]] != conv.roles[0]: 23 | # Skip the first one if it is not from human 24 | source = source[1:] 25 | 26 | conv.messages = [] 27 | for j, sentence in enumerate(source): 28 | role = roles[sentence["from"]] 29 | assert role == conv.roles[j % 2], f"{i}" 30 | conv.append_message(role, sentence["value"]) 31 | conversations.append(conv.get_prompt()) 32 | 33 | # Tokenize conversations 34 | input_ids = tokenizer( 35 | conversations, 36 | return_tensors="pt", 37 | padding="max_length", 38 | max_length=tokenizer.model_max_length, 39 | truncation=True, 40 | ).input_ids 41 | targets = input_ids.clone() 42 | 43 | assert conv.sep_style == SeparatorStyle.TWO 44 | 45 | # TODO Mask targets 46 | # sep = conv.sep + conv.roles[1] + ": " 47 | # for conversation, target in zip(conversations, targets): 48 | # total_len = int(target.ne(tokenizer.pad_token_id).sum()) 49 | # 50 | # rounds = conversation.split(conv.sep2) 51 | # cur_len = 1 52 | # target[:cur_len] = IGNORE_TOKEN_ID 53 | # for i, rou in enumerate(rounds): 54 | # if rou == "": 55 | # break 56 | # 57 | # parts = rou.split(sep) 58 | # if len(parts) != 2: 59 | # break 60 | # parts[0] += sep 61 | # round_len = len(tokenizer(rou).input_ids) 62 | # instruction_len = len(tokenizer(parts[0]).input_ids) - 2 63 | # 64 | # target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID 65 | # 66 | # cur_len += round_len 67 | # target[cur_len:] = IGNORE_TOKEN_ID 68 | # 69 | # if False: 70 | # z = target.clone() 71 | # z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 72 | # rank0_print(tokenizer.decode(z)) 73 | # 74 | # if cur_len < tokenizer.model_max_length: 75 | # if cur_len != total_len: 76 | # target[:] = IGNORE_TOKEN_ID 77 | # 
rank0_print( 78 | # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 79 | # f" (ignored)" 80 | # ) 81 | 82 | return dict( 83 | input_ids=input_ids, 84 | labels=targets, 85 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 86 | ) 87 | 88 | 89 | class SupervisedDataset(Dataset): 90 | """Dataset for supervised fine-tuning.""" 91 | 92 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): 93 | super(SupervisedDataset, self).__init__() 94 | 95 | print_rank_0("Formatting inputs...") 96 | sources = [example["conversations"] for example in raw_data] 97 | data_dict = preprocess(sources, tokenizer) 98 | 99 | self.input_ids = data_dict["input_ids"] 100 | self.labels = data_dict["labels"] 101 | self.attention_mask = data_dict["attention_mask"] 102 | 103 | def __len__(self): 104 | return len(self.input_ids) 105 | 106 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 107 | return dict( 108 | input_ids=self.input_ids[i], 109 | labels=self.labels[i], 110 | attention_mask=self.attention_mask[i], 111 | ) 112 | 113 | 114 | class LazySupervisedDataset(Dataset): 115 | """Dataset for supervised fine-tuning.""" 116 | 117 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): 118 | super(LazySupervisedDataset, self).__init__() 119 | self.tokenizer = tokenizer 120 | 121 | print_rank_0("Formatting inputs...Skip in lazy mode") 122 | self.tokenizer = tokenizer 123 | self.raw_data = raw_data 124 | self.cached_data_dict = {} 125 | 126 | def __len__(self): 127 | return len(self.raw_data) 128 | 129 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 130 | if i in self.cached_data_dict: 131 | return self.cached_data_dict[i] 132 | 133 | ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer) 134 | ret = dict( 135 | input_ids=ret["input_ids"][0], 136 | labels=ret["labels"][0], 137 | attention_mask=ret["attention_mask"][0], 138 | ) 139 | self.cached_data_dict[i] = ret 140 | 141 | return ret 142 | 143 | 144 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_path, lazy_preprocess) -> Dict: 145 | """Make dataset and collator for supervised fine-tuning.""" 146 | dataset_cls = (LazySupervisedDataset if lazy_preprocess else SupervisedDataset) 147 | print_rank_0("Loading data...") 148 | raw_data = json.load(open(data_path, "r")) 149 | print_rank_0(f"#data {len(raw_data)}") 150 | dataset = dataset_cls(raw_data, tokenizer=tokenizer) 151 | return dict(dataset=dataset) 152 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | aiofiles==23.1.0 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | altair==4.2.2 6 | anyio==3.6.2 7 | async-timeout==4.0.2 8 | attrs==23.1.0 9 | blessed==1.20.0 10 | brotlipy==0.7.0 11 | cachetools==5.3.0 12 | certifi==2022.12.7 13 | cffi==1.15.1 14 | charset-normalizer==2.0.4 15 | click==8.1.3 16 | contourpy==1.0.7 17 | cryptography==39.0.1 18 | cycler==0.11.0 19 | entrypoints==0.4 20 | fastapi==0.95.1 21 | ffmpy==0.3.0 22 | filelock==3.12.0 23 | fonttools==4.39.3 24 | frozenlist==1.3.3 25 | fsspec==2023.4.0 26 | gpustat==1.1 27 | gradio==3.24.1 28 | gradio_client==0.1.4 29 | h11==0.14.0 30 | httpcore==0.17.0 31 | httpx==0.24.0 32 | huggingface-hub==0.14.1 33 | idna==3.4 34 | importlib-resources==5.12.0 35 | Jinja2==3.1.2 36 | jsonschema==4.17.3 37 | kiwisolver==1.4.4 38 | linkify-it-py==2.0.2 39 | markdown-it-py==2.2.0 40 | MarkupSafe==2.1.2 41 | 
matplotlib==3.7.1
42 | mdit-py-plugins==0.3.3
43 | mdurl==0.1.2
44 | mkl-fft==1.3.6
45 | mkl-random==1.2.2
46 | mkl-service==2.4.0
47 | multidict==6.0.4
48 | numpy==1.24.3
49 | nvidia-ml-py==11.525.112
50 | nvitop==1.1.2
51 | orjson==3.8.11
52 | packaging==23.1
53 | pandas==2.0.1
54 | Pillow==9.4.0
55 | pip==23.0.1
56 | psutil==5.9.5
57 | pycparser==2.21
58 | pydantic==1.10.7
59 | pydub==0.25.1
60 | pyOpenSSL==23.0.0
61 | pyparsing==3.0.9
62 | pyrsistent==0.19.3
63 | PySocks==1.7.1
64 | python-dateutil==2.8.2
65 | python-multipart==0.0.6
66 | pytz==2023.3
67 | PyYAML==6.0
68 | regex==2023.5.5
69 | requests==2.29.0
70 | semantic-version==2.10.0
71 | setuptools==66.0.0
72 | six==1.16.0
73 | sniffio==1.3.0
74 | starlette==0.26.1
75 | termcolor==2.3.0
76 | tokenizers==0.13.3
77 | toolz==0.12.0
78 | torch==1.12.1
79 | torchaudio==0.12.1
80 | torchvision==0.13.1
81 | tqdm==4.65.0
82 | transformers==4.28.1
83 | typing_extensions==4.5.0
84 | tzdata==2023.3
85 | uc-micro-py==1.0.2
86 | urllib3==1.26.15
87 | uvicorn==0.22.0
88 | wcwidth==0.2.6
89 | websockets==11.0.2
90 | wheel==0.38.4
91 | yarl==1.9.2
92 | zipp==3.15.0
93 | 
-------------------------------------------------------------------------------- /src/utils/230525模型对比表.xlsx: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/utils/230525模型对比表.xlsx
-------------------------------------------------------------------------------- /src/utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/utils/__init__.py
-------------------------------------------------------------------------------- /src/utils/compare_score_count.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/utils/compare_score_count.png
-------------------------------------------------------------------------------- /src/utils/dataset_utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | 
4 | 
5 | def check_ratio_valid(val_ratio: float, test_ratio: float) -> bool:
6 |     return (val_ratio >= 0.0) & (test_ratio >= 0.0) & (val_ratio + test_ratio < 1.0)
7 | 
8 | 
9 | def split_data(data: list, val_ratio: float = 0.05, test_ratio: float = 0.05):
10 |     if not check_ratio_valid(val_ratio, test_ratio):
11 |         print(f'Error: ratios: {val_ratio}, {test_ratio} not valid')
12 |     # set random seed
13 |     np.random.seed(42)
14 |     data_size = len(data)
15 |     shuffled_indices = np.random.permutation(data_size)
16 |     val_split_pos = int(data_size * val_ratio)
17 |     test_split_pos = val_split_pos + int(data_size * test_ratio)
18 |     data = pd.Series(data)
19 |     _val = data.iloc[shuffled_indices[:val_split_pos]].values.tolist()
20 |     _test = data.iloc[shuffled_indices[val_split_pos:test_split_pos]].values.tolist()
21 |     _train = data.iloc[shuffled_indices[test_split_pos:]].values.tolist()
22 |     return _train, _val, _test
23 | 
-------------------------------------------------------------------------------- /src/utils/file_utils.py: --------------------------------------------------------------------------------
1 | """
2 | @Desc:
3 | @Reference:
4 | """
5 | 
6 | import os
7 | import pickle
8 | import shutil
9 | import json
10 | import linecache
11 | from pathlib import Path
12 | import joblib
13 | 
14 | 
15 | def save_json(content, path, indent=4, **json_dump_kwargs):
16 |     with open(path, "w", encoding="utf-8") as fw:
17 |         json.dump(str(content), fw, indent=indent, **json_dump_kwargs)
18 | 
19 | 
20 | def load_json(path):
21 |     with open(path, "r", encoding="utf-8") as fr:
22 |         return json.load(fr)
23 | 
24 | 
25 | def copy_file_or_dir(source_path, target_path):
26 |     shutil.copy(source_path, target_path)
27 | 
28 | 
29 | def output_obj_to_file(obj, file_path):
30 |     with open(file_path, 'w', encoding='utf-8') as fw:
31 |         fw.write(str(obj))
32 | 
33 | 
34 | def output_long_text_to_file(long_text, file_path, delimiters=None):
35 |     with open(file_path, 'w', encoding='utf-8') as fw:
36 |         long_text = str(long_text)
37 |         if delimiters is None:
38 |             delimiters = [',', '.', ';', '!', '?']
39 |         elif isinstance(delimiters, list):
40 |             pass
41 |         elif isinstance(delimiters, str):
42 |             delimiters = [delimiters]
43 |         else:
44 |             raise ValueError
45 | 
46 |         for punc in delimiters:
47 |             long_text = long_text.replace(punc, punc + '\n')
48 |         fw.write(long_text)
49 | 
50 | 
51 | def file_to_lines(data_file: str):
52 |     if not os.path.exists(data_file):
53 |         print(f'Error: file_path {data_file} does not exist')
54 |     with open(data_file, 'r', encoding='utf-8') as fr:
55 |         return fr.readlines()
56 | 
57 | 
58 | def lines_to_file(lines: list, data_file: str):
59 |     _dir = os.path.dirname(data_file)
60 |     if not os.path.exists(_dir):
61 |         print(f'Error: file directory {_dir} does not exist')
62 |     with open(data_file, 'w', encoding='utf-8') as fw:
63 |         return fw.writelines(lines)
64 | 
65 | 
66 | def pickle_load(path):
67 |     """pickle.load(path)"""
68 |     with open(path, "rb") as f:
69 |         return pickle.load(f)
70 | 
71 | 
72 | def pickle_save(obj, path):
73 |     """pickle.dump(obj, path)"""
74 |     with open(path, "wb") as f:
75 |         return pickle.dump(obj, f)
76 | 
77 | 
78 | def joblib_load(path):
79 |     """joblib.load(path)"""
80 |     with open(path, "rb") as f:
81 |         return joblib.load(f)
82 | 
83 | 
84 | def joblib_save(obj, path):
85 |     """joblib.dump(obj, path)"""
86 |     with open(path, "wb") as f:
87 |         return joblib.dump(obj, f)
88 | 
89 | 
90 | def get_line_from_file(file_path, index):
91 |     file_path = Path(file_path)
92 |     if not file_path.exists():
93 |         raise FileNotFoundError(f"{file_path} not existing.")
94 |     return linecache.getline(str(file_path), index)
95 | 
-------------------------------------------------------------------------------- /src/utils/gen_utils.py: --------------------------------------------------------------------------------
1 | """
2 | @Desc:
3 | @Reference:
4 |     - torch.topk
5 |         https://pytorch.org/docs/stable/generated/torch.topk.html
6 |     - torch.multinomial
7 |         https://pytorch.org/docs/stable/generated/torch.multinomial.html
8 | """
9 | 
10 | import torch
11 | 
12 | 
13 | def gather_nd(x, indices):
14 |     newshape = list(indices.shape[:-1] + x.shape[indices.shape[-1]:]) + [1]
15 |     indices = indices.view(-1, indices.shape[-1]).tolist()
16 |     out = torch.cat([torch.tensor([x.__getitem__(tuple(i))]) for i in indices]).reshape(tuple(newshape))
17 |     return out
18 | 
19 | 
20 | def top_p_logits(logits, p, device=None):
21 |     """Nucleus sampling"""
22 |     if device is None:
23 |         device = "cuda:0" if torch.cuda.is_available() else "cpu"
24 |     batch, _ = logits.size()
25 |     sorted_logits, _ = torch.sort(logits, descending=True, axis=-1)
26 |     cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), axis=-1)
27 |     cumulative_position = torch.sum((cumulative_probs <= p).to(torch.int32), axis=-1) - 
1 28 | indices = torch.stack([ 29 | torch.arange(0, batch).to(device), 30 | # number of indices to include 31 | torch.max(cumulative_position, torch.zeros([batch], dtype=cumulative_position.dtype).to(device)), 32 | ], axis=-1) 33 | min_values = gather_nd(sorted_logits, indices).to(device) 34 | return torch.where( 35 | logits < min_values, 36 | torch.ones_like(logits) * -1e10, 37 | logits, 38 | ) 39 | 40 | 41 | def sample_sequence(input_ids, model, max_length, top_p=0.9, tokenizer=None, 42 | no_sample=False, device=None): 43 | if not tokenizer: 44 | raise ModuleNotFoundError("tokenizer needed") 45 | if device is None: 46 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 47 | batch_size = input_ids.size()[0] 48 | decoder_input_ids = torch.tensor([tokenizer.eos_token_id for _ in range(batch_size)])[:, None].to(device) 49 | for _ in range(max_length): 50 | outputs = model(input_ids, decoder_input_ids=decoder_input_ids, use_cache=False, return_dict=True) 51 | logits = outputs["logits"] 52 | logits = logits[:, -1, :] 53 | 54 | if no_sample: 55 | probs = torch.softmax(logits, dim=-1) 56 | prev = torch.topk(probs, 1).indices 57 | else: 58 | logits = top_p_logits(logits, p=top_p) 59 | probs = torch.softmax(logits, dim=-1) 60 | prev = torch.multinomial(probs, 1) 61 | decoder_input_ids = torch.cat([decoder_input_ids, prev], 1) 62 | # early stop 63 | if prev[:, 0].eq(tokenizer.eos_token_id).sum() == prev.shape[0]: 64 | break 65 | return decoder_input_ids 66 | 67 | 68 | def ids_to_clean_string(token_list, tokenizer, remain_sp_tokens=False): 69 | real_s = 0 70 | for index_, token_ in enumerate(token_list): 71 | if token_ not in [tokenizer.bos_token_id, tokenizer.eos_token_id]: 72 | real_s = index_ 73 | break 74 | token_list = token_list[real_s:] 75 | string = tokenizer.decode(token_list, skip_special_tokens=False) 76 | # string = string[:string.find(tokenizer.eos_token)].strip() 77 | if not remain_sp_tokens: # remove special tokens in output 78 | for one in tokenizer.all_special_tokens: 79 | string = string.replace(one, " ") 80 | string = " ".join([one for one in string.split()]) 81 | return string 82 | -------------------------------------------------------------------------------- /src/utils/huggingface_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | """ 6 | from typing import List 7 | 8 | from huggingface_hub import hf_hub_url, snapshot_download 9 | 10 | 11 | def download_specific_file(repository_id: str = "lysandre/arxiv-nlp", 12 | filename: str = "config.json"): 13 | hf_hub_url(repo_id=repository_id, filename=filename) 14 | 15 | 16 | def download_repository(repository_id: str = "lysandre/arxiv-nlp", 17 | ignore_regex: List[str] = ["*.msgpack", "*.h5", "*.tflite"]): 18 | local_folder = snapshot_download(repo_id=repository_id, ignore_regex=ignore_regex) 19 | print(f"{repository_id} downloaded in {local_folder}") 20 | 21 | 22 | if __name__ == '__main__': 23 | download_repository(repository_id="bart-base") 24 | -------------------------------------------------------------------------------- /src/utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | 4 | 5 | class AccumulatedLoss(torchmetrics.Metric): 6 | def __init__(self, dist_sync_on_step=True): 7 | # call `self.add_state`for every internal state that is needed for the metrics computations 8 | # dist_reduce_fx indicates the function that should be used to 
reduce 9 | # state from multiple processes 10 | super().__init__(dist_sync_on_step=dist_sync_on_step) 11 | 12 | self.add_state("loss", default=torch.tensor(0), dist_reduce_fx="sum") 13 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 14 | 15 | def update(self, loss: torch.Tensor): 16 | # update metric states 17 | self.loss += loss 18 | self.total += 1 19 | 20 | def compute(self): 21 | # compute final result 22 | return self.loss.float() / self.total 23 | 24 | 25 | if __name__ == '__main__': 26 | # class LitModel(LightningModule): 27 | # def __init__(self): 28 | # # 1. initialize the metric 29 | # self.accumulated_loss = AccumulatedLoss() 30 | # 31 | # def training_step(self, batch, batch_idx): 32 | # x, y = batch 33 | # loss = self(x) 34 | # 35 | # # 2. compute the metric 36 | # self.accumulated_loss(loss) 37 | # 38 | # # 3. log it 39 | # self.log("train_loss", self.accumulated_loss) 40 | # 41 | # def training_epoch_end(self, outputs): 42 | # # 4. reset the metric 43 | # self.accumulated_loss.reset() 44 | pass 45 | -------------------------------------------------------------------------------- /src/utils/nlg_eval_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | """ 6 | 7 | import os 8 | from collections import defaultdict 9 | from typing import Callable, Dict, Iterable, List, Tuple, Union 10 | import numpy as np 11 | 12 | from rouge_score import rouge_scorer, scoring 13 | import nltk 14 | import re 15 | import string 16 | 17 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 18 | 19 | 20 | def line_normalize(line: str): 21 | line = " ".join(line.strip().split()) 22 | return line 23 | 24 | 25 | def calculate_bleu(ref_lines, gen_lines, metrics: dict = None): 26 | if metrics is None: 27 | metrics = {} 28 | for bleu_i in range(1, 5): 29 | weights = tuple([1. / bleu_i for _ in range(bleu_i)]) 30 | metrics[f"bleu-{bleu_i}"] = round(nltk.translate.bleu_score.corpus_bleu( 31 | list_of_references=[[ref] for ref in ref_lines], 32 | hypotheses=gen_lines, 33 | weights=weights), 4) 34 | return metrics 35 | 36 | 37 | def extract_rouge_mid_statistics(dct): 38 | new_dict = {} 39 | for k1, v1 in dct.items(): 40 | mid = v1.mid 41 | new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]} 42 | return new_dict 43 | 44 | 45 | def calculate_rouge( 46 | pred_lines: List[str], 47 | tgt_lines: List[str], 48 | use_stemmer=True, 49 | rouge_keys=ROUGE_KEYS, 50 | return_precision_and_recall=False, 51 | bootstrap_aggregation=True, 52 | newline_sep=True, 53 | ) -> Dict: 54 | """Calculate rouge using rouge_scorer package. 55 | 56 | Args: 57 | pred_lines: list of summaries generated by model 58 | tgt_lines: list of groundtruth summaries (e.g. contents of val.target) 59 | use_stemmer: Bool indicating whether Porter stemmer should be used to 60 | strip word suffixes to improve matching. 61 | rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum 62 | return_precision_and_recall: (False) whether to also return precision and recall. 63 | bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False 64 | this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]`` 65 | newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL 66 | on multi sentence summaries (CNN/DM dataset). 
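        Example (illustrative; exact scores depend on the rouge_score version):
            >>> calculate_rouge(["the cat sat on the mat"], ["a cat sat on the mat"])
            {'rouge1': ..., 'rouge2': ..., 'rougeL': ..., 'rougeLsum': ...}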
67 | 68 | Returns: 69 | Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys 70 | 71 | """ 72 | scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer) 73 | aggregator = scoring.BootstrapAggregator() 74 | for pred, tgt in zip(tgt_lines, pred_lines): 75 | # rougeLsum expects "\n" separated sentences within a summary 76 | if newline_sep: 77 | pred = pred + "\n" 78 | tgt = tgt + "\n" 79 | scores = scorer.score(pred, tgt) 80 | aggregator.add_scores(scores) 81 | 82 | if bootstrap_aggregation: 83 | result = aggregator.aggregate() 84 | if return_precision_and_recall: 85 | return extract_rouge_mid_statistics(result) # here we return dict 86 | else: 87 | return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} 88 | 89 | else: 90 | return aggregator._scores # here we return defaultdict(list) 91 | 92 | 93 | def repetition_distinct_metric(gen_lines, metrics: dict = None, repetition_times=2): 94 | if metrics is None: 95 | metrics = {} 96 | 97 | for gram_n in range(1, 5): 98 | repetition_count = 0 99 | all_ngram = defaultdict(int) 100 | all_ngram_num = 0 101 | for gen_idx, line in enumerate(gen_lines): 102 | n_grams = ["_".join(gram) for gram in nltk.ngrams(line, n=gram_n)] 103 | all_ngram_num += len(n_grams) 104 | # for distinct 105 | for gram in n_grams: 106 | all_ngram[gram] += 1 107 | # for repetition 108 | for gram in set(n_grams): 109 | if n_grams.count(gram) >= repetition_times: 110 | repetition_count += 1 111 | break 112 | metrics[f"repetition-{gram_n}"] = "%.4f" % (repetition_count / float(len(gen_lines))) 113 | metrics[f"distinct-{gram_n}"] = "%.4f" % (len(all_ngram) / float(all_ngram_num)) 114 | return metrics 115 | 116 | 117 | def normalize_answer(s): 118 | """Lower text and remove punctuation, articles and extra whitespace.""" 119 | 120 | def remove_articles(text): 121 | return re.sub(r'\b(a|an|the)\b', ' ', text) 122 | 123 | def white_space_fix(text): 124 | return ' '.join(text.split()) 125 | 126 | def remove_punc(text): 127 | exclude = set(string.punctuation) 128 | return ''.join(ch for ch in text if ch not in exclude) 129 | 130 | def lower(text): 131 | return text.lower() 132 | 133 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 134 | 135 | 136 | def exact_match_score(prediction, ground_truth): 137 | return normalize_answer(prediction) == normalize_answer(ground_truth) 138 | 139 | 140 | def metric_max_over_ground_truths(metric_fn, predictions: Union[str, List[str]], ground_truths: List[str]): 141 | scores_for_ground_truths = [] 142 | 143 | if isinstance(predictions, str): 144 | predictions = [predictions] 145 | 146 | for prediction in predictions: 147 | for ground_truth in ground_truths: 148 | score = metric_fn(prediction, ground_truth) 149 | scores_for_ground_truths.append(score) 150 | 151 | return max(scores_for_ground_truths) 152 | -------------------------------------------------------------------------------- /src/utils/print_utils.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.utilities import rank_zero_only 2 | 3 | 4 | @rank_zero_only 5 | def print_rank_0(*args): 6 | print(*args) 7 | -------------------------------------------------------------------------------- /src/utils/string_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | 6 | """ 7 | 8 | 9 | def str2bool(v): 10 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 11 | return True 12 | elif v.lower() in 
('no', 'false', 'f', 'n', '0'): 13 | return False 14 | else: 15 | import argparse 16 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 17 | 18 | 19 | def are_same_strings(string1: str, string2: str): 20 | if not isinstance(string1, str) or not isinstance(string2, str): 21 | raise ValueError("input should be strings") 22 | return string1.strip().lower() == string2.strip().lower() 23 | 24 | 25 | def rm_extra_spaces(string: str): 26 | return " ".join(string.strip().split()) 27 | -------------------------------------------------------------------------------- /src/utils/task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/utils/task/__init__.py -------------------------------------------------------------------------------- /src/utils/task/event_analyzer_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from multiprocessing import Process 4 | 5 | import spacy 6 | from typing import List 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | import pandas as pd 10 | from collections import Counter 11 | 12 | from src.configuration.constants import BASE_DIR 13 | from preprocessing.event_trigger.event_predictor import EventPredictor 14 | 15 | 16 | class EventAnalyzer(object): 17 | def event_trainings_nums(self, event_predictor: EventPredictor): 18 | event_graph = event_predictor.event_graph 19 | counter = Counter() 20 | for id, event in event_graph.events.items(): 21 | counter[id] = len(event.extracted_sents) 22 | return counter 23 | 24 | def analyze_events(self, event_predictor: EventPredictor): 25 | event_graph = event_predictor.event_graph 26 | while True: 27 | event_string = input("Enter an event_string (q to quit): ") 28 | if event_string in ["quit", "q"]: 29 | break 30 | event = event_graph.find_event(event_string) 31 | if not event: 32 | print(f"no event found for {event_string}") 33 | continue 34 | else: 35 | print(f"{event_string} next candidates: {event_graph.next_candidates(event.uuid, limit=3)}") 36 | print(f"{event_string} prev candidates: {event_graph.prev_events(event.uuid, limit=3)}") 37 | -------------------------------------------------------------------------------- /src/utils/task/event_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | """ 6 | 7 | import sys 8 | from pathlib import Path 9 | 10 | from collections import Counter 11 | from typing import List 12 | 13 | from preprocessing.event_trigger.event_ontology import EventGraph 14 | from preprocessing.event_trigger.event_extractor import EventExtractor 15 | 16 | 17 | def line_to_event_list(line: str): 18 | clean_line = line.replace(EventGraph.event_s, "").replace(EventGraph.event_e, "").strip() 19 | events = [one.strip() for one in clean_line.split(EventGraph.event_sep)] 20 | return events 21 | 22 | 23 | def remove_empty_event_lines(src_lines: List[str], tgt_lines: List[str]): 24 | new_src = [] 25 | new_tgt = [] 26 | for s_line, t_line in zip(src_lines, tgt_lines): 27 | if len(line_to_event_list(s_line)) > 0: 28 | new_src.append(s_line) 29 | new_tgt.append(t_line) 30 | return new_src, new_tgt 31 | 32 | 33 | if __name__ == '__main__': 34 | pass 35 | -------------------------------------------------------------------------------- /src/utils/task/model_utils.py: --------------------------------------------------------------------------------
1 | import itertools 2 | import json 3 | import os 4 | from logging import getLogger 5 | from pathlib import Path 6 | from typing import Callable, Dict, Iterable, List, Tuple, Union 7 | 8 | import numpy as np 9 | from sacrebleu import corpus_bleu 10 | from torch import nn 11 | 12 | from transformers import EvalPrediction, PreTrainedTokenizer 13 | 14 | try: 15 | from fairseq.data.data_utils import batch_by_size 16 | 17 | FAIRSEQ_AVAILABLE = True 18 | except (ImportError, ModuleNotFoundError): 19 | FAIRSEQ_AVAILABLE = False 20 | 21 | from src.utils.nlg_eval_utils import calculate_rouge 22 | 23 | 24 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 25 | """From fairseq""" 26 | if target.dim() == lprobs.dim() - 1: 27 | target = target.unsqueeze(-1) 28 | nll_loss = -lprobs.gather(dim=-1, index=target) 29 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 30 | if ignore_index is not None: 31 | pad_mask = target.eq(ignore_index) 32 | nll_loss.masked_fill_(pad_mask, 0.0) 33 | smooth_loss.masked_fill_(pad_mask, 0.0) 34 | else: 35 | nll_loss = nll_loss.squeeze(-1) 36 | smooth_loss = smooth_loss.squeeze(-1) 37 | 38 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 39 | smooth_loss = smooth_loss.sum() 40 | eps_i = epsilon / lprobs.size(-1) 41 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 42 | return loss, nll_loss 43 | 44 | 45 | def build_compute_metrics_fn(tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]: 46 | def non_pad_len(tokens: np.ndarray) -> int: 47 | return np.count_nonzero(tokens != tokenizer.pad_token_id) 48 | 49 | def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: 50 | pred_strs = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) 51 | label_strs = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) 52 | pred_strs = list(map(str.strip, pred_strs)) 53 | label_strs = list(map(str.strip, label_strs)) 54 | return pred_strs, label_strs 55 | 56 | def summarization_metrics(pred: EvalPrediction) -> Dict: 57 | pred_str, label_str = decode_pred(pred) 58 | rouge: Dict = calculate_rouge(pred_str, label_str) 59 | non_pad_predictions = list(map(non_pad_len, pred.predictions)) 60 | summ_len = np.round(np.mean(non_pad_predictions), 1) 61 | rouge.update({"gen_len": summ_len}) 62 | return rouge 63 | 64 | compute_metrics_fn = summarization_metrics 65 | return compute_metrics_fn 66 | 67 | 68 | def trim_batch( 69 | input_ids, 70 | pad_token_id, 71 | attention_mask=None, 72 | ): 73 | """Remove columns that are populated exclusively by pad_token_id""" 74 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 75 | if attention_mask is None: 76 | return input_ids[:, keep_column_mask] 77 | else: 78 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 79 | 80 | 81 | def flatten_list(generated_ids: List[List]): 82 | return [x for x in itertools.chain.from_iterable(generated_ids)] 83 | 84 | 85 | # Utilities for freezing parameters and checking whether they are frozen 86 | def freeze_params(model: nn.Module): 87 | """Set requires_grad=False for each of model.parameters()""" 88 | for param in model.parameters(): 89 | param.requires_grad = False 90 | 91 | 92 | def freeze_embeds(model): 93 | """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" 94 | model_type = model.config.model_type 95 | 96 | if model_type == "t5": 97 | freeze_params(model.shared) 98 | for model_ in [model.encoder, model.decoder]: 99 | 
freeze_params(model_.embed_tokens) 100 | elif model_type == "fsmt": 101 | for model_ in [model.model.encoder, model.model.decoder]: 102 | freeze_params(model_.embed_positions) 103 | freeze_params(model_.embed_tokens) 104 | else: 105 | freeze_params(model.model.shared) 106 | for model_ in [model.model.encoder, model.model.decoder]: 107 | freeze_params(model_.embed_positions) 108 | freeze_params(model_.embed_tokens) 109 | 110 | 111 | def grad_status(model: nn.Module) -> Iterable: 112 | return (par.requires_grad for par in model.parameters()) 113 | 114 | 115 | def any_requires_grad(model: nn.Module) -> bool: 116 | return any(grad_status(model)) 117 | 118 | 119 | def assert_all_frozen(model): 120 | model_grads: List[bool] = list(grad_status(model)) 121 | model_grads_int = list(map(int, model_grads)) 122 | n_require_grad = sum(model_grads_int) 123 | n_params = len(model_grads) 124 | assert not any(model_grads), f"{n_require_grad / n_params:.1%} of {n_params} weights require grad" 125 | 126 | 127 | def assert_not_all_frozen(model): 128 | model_grads: List[bool] = list(grad_status(model)) 129 | npars = len(model_grads) 130 | assert any(model_grads), f"none of {npars} weights require grad" 131 | 132 | 133 | def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]: 134 | """ 135 | Parse an argv list of unspecified command line args to a dict. 136 | Assumes all values are either numeric or boolean in the form of true/false. 137 | """ 138 | result = {} 139 | assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}" 140 | num_pairs = len(unparsed_args) // 2 141 | for pair_num in range(num_pairs): 142 | i = 2 * pair_num 143 | assert unparsed_args[i].startswith("--") 144 | if unparsed_args[i + 1].lower() == "true": 145 | value = True 146 | elif unparsed_args[i + 1].lower() == "false": 147 | value = False 148 | else: 149 | try: 150 | value = int(unparsed_args[i + 1]) 151 | except ValueError: 152 | value = float(unparsed_args[i + 1]) # this can raise another informative ValueError 153 | 154 | result[unparsed_args[i][2:]] = value 155 | return result 156 | 157 | 158 | def write_txt_file(ordered_tgt, path): 159 | f = Path(path).open("w") 160 | for ln in ordered_tgt: 161 | f.write(ln + "\n") 162 | f.flush() 163 | 164 | 165 | def chunks(lst, n): 166 | """Yield successive n-sized chunks from lst.""" 167 | for i in range(0, len(lst), n): 168 | yield lst[i: i + n] 169 | 170 | 171 | def smart_tokenizer_and_embedding_resize( 172 | special_tokens_dict: Dict, 173 | tokenizer, 174 | model, 175 | ): 176 | """Resize tokenizer and embedding. 177 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
178 | """ 179 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 180 | model.resize_token_embeddings(len(tokenizer)) 181 | 182 | if num_new_tokens > 0: 183 | input_embeddings = model.get_input_embeddings().weight.data 184 | output_embeddings = model.get_output_embeddings().weight.data 185 | 186 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 187 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 188 | 189 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 190 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 191 | -------------------------------------------------------------------------------- /src/utils/task/stat_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | """ 6 | 7 | import os 8 | import sys 9 | from pathlib import Path 10 | 11 | from collections import Counter 12 | from typing import List 13 | 14 | from src.configuration.constants import BASE_DIR 15 | from src.utils.task.event_utils import line_to_event_list 16 | from preprocessing.event_trigger.event_extractor import EventExtractor 17 | 18 | 19 | def data_size(file: str): 20 | with open(file, "r", encoding="utf-8") as fr: 21 | return len(fr.readlines()) 22 | 23 | 24 | def text_stat(file: str): 25 | with open(file, "r", encoding="utf-8") as fr: 26 | file_lines = fr.readlines() 27 | stat = Counter() 28 | for line in file_lines: 29 | sents = line.strip().split(".") 30 | tokens = line.strip().split() 31 | stat["sents"] += len(sents) 32 | stat["tokens"] += len(tokens) 33 | return stat 34 | 35 | 36 | def event_stat(file: str): 37 | with open(file, "r", encoding="utf-8") as fr: 38 | event_lines = fr.readlines() 39 | stat = Counter() 40 | for line in event_lines: 41 | events = line_to_event_list(line) 42 | stat["events"] += len(events) 43 | return stat 44 | 45 | 46 | def event_graph_stat(event_extractor: EventExtractor): 47 | stat = Counter() 48 | event_graph = event_extractor.event_graph 49 | stat["events"] = event_graph.nodes_num 50 | stat["relations"] = event_graph.edges_num 51 | stat["avg_degree"] = event_graph.avg_degree 52 | stat["triggers"] = event_graph.triggers_num 53 | return stat 54 | 55 | 56 | def parse_files(src_file, tgt_file, event_file): 57 | counter = text_stat(src_file) + text_stat(tgt_file) + event_stat(event_file) 58 | # src 的 events 没算 59 | counter["data_size"] = data_size(src_file) 60 | counter["events"] += counter["data_size"] 61 | return counter 62 | 63 | 64 | def parse_event_graphs(dataset_name: str): 65 | save_path = f"{BASE_DIR}/output/event-trigger/cache/{dataset_name}_event_graph.pkl" 66 | if os.path.exists(save_path): 67 | print(f"extractor loaded from {save_path}") 68 | event_extractor = EventExtractor.load(save_path) 69 | return event_graph_stat(event_extractor) 70 | -------------------------------------------------------------------------------- /src/utils/total_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/utils/total_score.png -------------------------------------------------------------------------------- /src/utils/total_score_ratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siat-nlp/HanFei/480b28f2651840ab80fcf0fb32f04b001d17ff24/src/utils/total_score_ratio.png 
-------------------------------------------------------------------------------- /src/utils/wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | """ 6 | 7 | 8 | def print_done(desc: str): 9 | def wrapper(func): 10 | def decorate(*args): 11 | print(f"{desc}...", end="") 12 | func(*args) 13 | print('- done') 14 | 15 | return decorate 16 | 17 | return wrapper 18 | -------------------------------------------------------------------------------- /tasks/download_hf_models.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | git lfs install 4 | 5 | CURRENT_DIR=$(dirname $(readlink -f "$0")) 6 | BASE_DIR=$(dirname ${CURRENT_DIR}) 7 | TARGET_DIR="${BASE_DIR}/resources/external_models" 8 | 9 | cd ${TARGET_DIR} 10 | git clone https://huggingface.co/facebook/bart-base -------------------------------------------------------------------------------- /tasks/task/CoT_test.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import torch 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | import re 5 | import random 6 | import ast 7 | from collections import Counter 8 | import json 9 | from statistics import mode 10 | from tqdm import tqdm 11 | 12 | 13 | def prompt_without_equtation(dataset): 14 | question = dataset["question"] 15 | answer = list(map(lambda x: re.sub("<<.*>>", "", x), dataset["answer"])) 16 | answer = list(map(lambda x: re.sub("\n####", " The answer is:", x), answer)) 17 | input_list = [] 18 | prompt = [] 19 | for q in range(len(question)): 20 | prompt.append("\nQ: " + question[q] + "\nA: " + answer[q] + " \n") 21 | return prompt 22 | 23 | 24 | def add_prompts_to_input(sample_from_prompts, test_question): 25 | test_input_with_prompts = sample_from_prompts + "\nQ: " + test_question + "\nA: " 26 | return test_input_with_prompts 27 | 28 | 29 | def prompt_with_equtation(dataset): 30 | question = dataset["question"] 31 | answer = list( 32 | map(lambda x: re.sub("\n####", " The answer is:", x), dataset["answer"]) 33 | ) 34 | input_list = [] 35 | prompt = [] 36 | for q in range(len(question)): 37 | prompt.append("\nQ: " + question[q] + "\nA: " + answer[q] + " \n") 38 | return prompt 39 | 40 | 41 | MODEL_NAME = "/mntnfs/med_data5/zhanghongbo/general_pretrain/output/task/GPT_Code/best_tfmr" 42 | # MODEL_NAME = "gpt2-large" 43 | dataset = load_dataset("gsm8k", 'main', cache_dir='/mntnfs/med_data5/zhanghongbo/general_pretrain/cache_dir') 44 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, 45 | cache_dir='/mntnfs/med_data5/zhanghongbo/general_pretrain/cache_dir') 46 | tokenizer.pad_token = tokenizer.eos_token 47 | tokenizer.pad_token_id = 50256 48 | model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, 49 | cache_dir='/mntnfs/med_data5/zhanghongbo/general_pretrain/cache_dir') 50 | model = model.cuda() 51 | 52 | input_prompts = prompt_without_equtation(dataset["train"]) 53 | random.seed(42) 54 | number_of_prompts = 4 55 | sample_from_prompts = " ".join( 56 | map(str, random.sample(input_prompts, number_of_prompts)) 57 | ) 58 | 59 | test_same_prompts = list( 60 | map( 61 | lambda x: add_prompts_to_input(sample_from_prompts, x), 62 | dataset["test"]["question"], 63 | ) 64 | ) 65 | print(sample_from_prompts) 66 | 67 | input_prompts_with_equtation = prompt_with_equtation(dataset["train"]) 68 | random.seed(42) 69 | sample_from_prompts_with_equtation = " ".join( 70 | 
map(str, random.sample(input_prompts_with_equtation, number_of_prompts)) 71 | ) 72 | test_same_prompts_with_equtation = list( 73 | map( 74 | lambda x: add_prompts_to_input(sample_from_prompts_with_equtation, x), 75 | dataset["test"]["question"], 76 | ) 77 | ) 78 | 79 | # test_same_prompts = test_same_prompts[:20] 80 | # test_same_prompts_with_equtation = test_same_prompts_with_equtation[:20] 81 | 82 | results_greedy = [] 83 | for i in [ 84 | test_same_prompts, 85 | test_same_prompts_with_equtation, 86 | ]: 87 | predicted_data = [] 88 | for j in tqdm(i): 89 | inputs = tokenizer(j, return_tensors="pt")["input_ids"].cuda() 90 | outputs = model.generate(inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id) 91 | predicted_data.append(tokenizer.decode(outputs[0])) 92 | results_greedy.append(predicted_data) 93 | 94 | with open("results_greedy", "w") as fp: 95 | json.dump(results_greedy, fp) 96 | 97 | 98 | def get_answer_from_generated_text(generated_text, question_from_test): 99 | answer = str() 100 | result = re.search( 101 | r"The answer is: \d+", generated_text.split(question_from_test)[-1] 102 | ) 103 | try: 104 | answer = re.search(r"\d+", result.group(0)).group(0) 105 | except: 106 | answer = "UNK" 107 | return answer 108 | 109 | 110 | test_dataset = dataset["test"]["question"] 111 | answers_test = [] 112 | for i in dataset["test"]["answer"]: 113 | answers_test.append(i.split("\n#### ")[-1]) 114 | 115 | all_answers = [] 116 | for dataset in results_greedy: 117 | results_chain_of_thoughts = [] 118 | for example in tqdm(range(len(dataset))): 119 | results_chain_of_thoughts.append( 120 | get_answer_from_generated_text(dataset[example], test_dataset[example]) 121 | ) 122 | all_answers.append(results_chain_of_thoughts) 123 | 124 | results_total_chain_of_thoughts = [] 125 | for answers in tqdm(range(len(all_answers))): 126 | results_by_method = [] 127 | for generated_answer in range(len(all_answers[answers])): 128 | if all_answers[answers][generated_answer] == answers_test[generated_answer]: 129 | results_by_method.append("yes") 130 | else: 131 | results_by_method.append("no") 132 | results_total_chain_of_thoughts.append(results_by_method) 133 | 134 | for i in tqdm(results_total_chain_of_thoughts): 135 | precision = int(Counter(i)["yes"] / len(i) * 100) 136 | print(f"Precision: {precision}%") 137 | 138 | 139 | def get_maj_answer(answers): 140 | while "UNK" in answers: 141 | answers.remove("UNK") 142 | if answers: 143 | mode_answer = mode(answers) 144 | else: 145 | mode_answer = "UNK" 146 | return mode_answer 147 | 148 | 149 | # Compare the test-set answers with the generated ones: "yes" means the answer matched, "no" means it did not.
150 | # "UNK" is also treated as "no" here, since no answer is produced in the output 151 | results_total_self_consistency = [] 152 | for answers in tqdm(range(len(all_answers_self_consistency))): 153 | results_by_method = [] 154 | for generated_answer in range(len(all_answers_self_consistency[answers])): 155 | if ( 156 | all_answers_self_consistency[answers][generated_answer] 157 | == answers_test[generated_answer] 158 | ): 159 | results_by_method.append("yes") 160 | else: 161 | results_by_method.append("no") 162 | results_total_self_consistency.append(results_by_method) 163 | 164 | for i in tqdm(results_total_self_consistency): 165 | precision = int(Counter(i)["yes"] / len(i) * 100) 166 | print(f"Precision: {precision}%") 167 | -------------------------------------------------------------------------------- /tasks/task/convert_metadata.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import jsonlines as jl 4 | import json 5 | import pandas as pd 6 | import random 7 | 8 | jlfile = '/fastdata/act21hz/general_pretrain/resources/PAQ.metadata.jsonl' 9 | converted_file = '/fastdata/act21hz/general_pretrain/resources/PAQ.QA.jsonl' 10 | df_file = '/fastdata/act21hz/general_pretrain/resources/PAQ.QA.tsv' 11 | df = pd.DataFrame(columns=['question', 'answer', 'passage_id']) 12 | results = [] 13 | for ln, line in enumerate(jl.open(jlfile)): 14 | data = line 15 | # data = {'question': data['question'], 'answer': data['answer'], 16 | # 'passage_id': [ans['passage_id'] for ans in data['answers']]} 17 | results.append( 18 | [data['question'], random.choice(data['answer']), [int(ans['passage_id']) for ans in data['answers']]]) 19 | print(f'Loaded {ln + 1} Items from {jlfile}', flush=True) if ln % 1000000 == 0 else None 20 | df = pd.DataFrame(results, columns=['question', 'answer', 'passage_id']) 21 | df.to_csv(df_file, sep='\t') 22 | -------------------------------------------------------------------------------- /tasks/task/corpus_stat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import random 4 | import numpy as np 5 | from pathlib import Path 6 | 7 | FILE_PATH = Path(__file__).absolute() 8 | BASE_DIR = FILE_PATH.parent.parent.parent 9 | sys.path.insert(0, str(BASE_DIR)) # allows running the script directly from the tasks folder 10 | 11 | from datasets import load_dataset 12 | from transformers import PreTrainedTokenizerFast, GPT2Tokenizer 13 | from src.utils.file_utils import pickle_save, pickle_load 14 | from tqdm import tqdm 15 | 16 | data = load_dataset('NeelNanda/c4-code-tokenized-2b', cache_dir='resources/')['train'] 17 | 18 | neox_digits_tokenizer = PreTrainedTokenizerFast.from_pretrained('NeelNanda/gpt-neox-tokenizer-digits', 19 | cache_dir='resources/') 20 | gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large', cache_dir='cache_dir/') 21 | token_freq = {} 22 | for t in range(gpt2_tokenizer.vocab_size): 23 | token_freq[t] = 0 24 | count = 0 25 | for piece in tqdm(data, total=100000): 26 | if count == 100000: 27 | break 28 | piece_sentence = neox_digits_tokenizer.decode(piece['tokens'], skip_special_tokens=True) 29 | piece_gpt2 = gpt2_tokenizer(piece_sentence)['input_ids'] 30 | for token in piece_gpt2: 31 | token_freq[token] += 1 32 | count += 1 33 | 34 | token_freq = sorted(token_freq.items(), key=lambda kv: (kv[1], kv[0])) 35 | bucket_length = len(token_freq) // 10 36 | count = 0 37 | bucket_num = 0 38 | token_to_freq = {} 39 | for token in token_freq: 40 | if count > bucket_length:
41 | count = 0 42 | bucket_num += 1 43 | token_to_freq[token] = bucket_num 44 | count += 1 45 | pickle_save(token_to_freq, 'resources/token_to_freq.pkl') 46 | print('token to freq saved') 47 | -------------------------------------------------------------------------------- /tasks/task/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | @Notes: 5 | WANDB is Weights and Biases Logger: 6 | https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.loggers.wandb.html 7 | """ 8 | 9 | import sys 10 | import json 11 | import numpy as np 12 | from pathlib import Path 13 | 14 | from tqdm import tqdm 15 | 16 | FILE_PATH = Path(__file__).absolute() 17 | BASE_DIR = FILE_PATH.parent.parent.parent 18 | sys.path.insert(0, str(BASE_DIR)) # run code in any path 19 | 20 | from src.configuration.task.config_args import parse_args_for_config 21 | from src.utils.file_utils import copy_file_or_dir, output_obj_to_file, pickle_save, pickle_load, joblib_save, \ 22 | joblib_load 23 | from src.utils import nlg_eval_utils 24 | from train import EventTriggerTrainer 25 | 26 | 27 | class EventTriggerTester(EventTriggerTrainer): 28 | def __init__(self, args): 29 | # parameters 30 | super().__init__(args) 31 | self.generation_dir = self.experiment_output_dir / "gen_result" 32 | self.generation_dir.mkdir(parents=True, exist_ok=True) 33 | self.tokenizer = self.model.tokenizer 34 | self.model.eval() 35 | self.test_output = None 36 | self.src_file = None 37 | self.tgt_file = None 38 | self.gen_file = None 39 | self.eval_file = None 40 | 41 | # customized 42 | self.dataset = self.model.test_dataloader().dataset 43 | self.output_prefix = f"{self.model.model_name}" 44 | self.test_output_store_path = self.cache_dir.joinpath(f"{self.output_prefix}_test_output.pkl") 45 | self.ppl_file = self.generation_dir / f"{self.output_prefix}_ppl.job" 46 | self.acc_file = self.generation_dir / f"{self.output_prefix}_acc.job" 47 | self.freq_acc_file = self.generation_dir / f"{self.output_prefix}_freq_acc.job" 48 | 49 | def test(self, ckpt_path=None): 50 | # if ckpt_path is None: 51 | # ckpt_path = self.checkpoints[-1] 52 | self.pl_trainer.test(model=self.model, ckpt_path=None) 53 | 54 | def init_test_output(self): 55 | if self.test_output_store_path.exists(): 56 | print(f"test output loaded from {self.test_output_store_path}") 57 | self.test_output = pickle_load(self.test_output_store_path) 58 | if self.test_output is None: 59 | self.model.store_test_output = True 60 | self.test() 61 | self.test_output = self.model.test_output 62 | print(f"test output stored to {self.test_output_store_path}") 63 | pickle_save(self.test_output, self.test_output_store_path) 64 | if self.test_output is None: 65 | raise ValueError("self.test_output cannot be None") 66 | 67 | def generate(self): 68 | self.init_test_output() 69 | print(f"model {self.model.model_name} generating") 70 | print(f"src_file: {self.src_file}\ntgt_file: {self.tgt_file}\ngen_file: {self.gen_file}\n") 71 | print(f"test_loss: {self.test_output['test_loss']}") 72 | # print(f"metrics: {self.test_output['log']}") 73 | 74 | joblib_save(self.test_output['log']['test_logits_acc'], self.acc_file) 75 | joblib_save(self.test_output['log']['test_logits_ppl'], self.ppl_file) 76 | print(f"test_logits_acc saved at {self.acc_file}") 77 | print(f"test_logits_ppl saved at {self.ppl_file}") 78 | 79 | # def eval_output(self): 80 | # self.init_test_output() 81 | # test_logits_acc = self.test_output['log']['test_logits_acc'] 
82 | # test_logits_ppl = self.test_output['log']['test_logits_ppl'] 83 | # 84 | # print("=" * 10) 85 | # 86 | # print(f"model {self.model.model_name} eval {self.gen_file}") 87 | # output_obj_to_file(json.dumps(metrics, indent=4), self.eval_file) 88 | # return metrics 89 | 90 | 91 | if __name__ == '__main__': 92 | hparams = parse_args_for_config() 93 | tester = EventTriggerTester(hparams) 94 | 95 | # generate predicted stories 96 | tester.generate() 97 | # tester.eval_output() 98 | -------------------------------------------------------------------------------- /tasks/task/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Desc: 3 | @Reference: 4 | - logger and WandLogger 5 | Weights and Biases is a third-party logger 6 | https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html 7 | @Notes: 8 | 9 | """ 10 | import sys 11 | import os 12 | import pytorch_lightning as pl 13 | 14 | from pathlib import Path 15 | 16 | FILE_PATH = Path(__file__).absolute() 17 | BASE_DIR = FILE_PATH.parent.parent.parent 18 | sys.path.insert(0, str(BASE_DIR)) # run code in any path 19 | 20 | from src.configuration.task.config_args import parse_args_for_config 21 | from src.models.task import gpt_code, llama_finetune 22 | from src.utils.wrapper import print_done 23 | from src.utils.string_utils import are_same_strings 24 | from src.models.basic_pl_trainer import BasicPLTrainer 25 | from src.modules.pl_callbacks import Seq2SeqLoggingCallback, Seq2SeqCheckpointCallback, SaveCheckpointEveryEpoch 26 | 27 | print("starting", flush=True) 28 | 29 | 30 | class EventTriggerTrainer(BasicPLTrainer): 31 | def __init__(self, args, trainer_name="event-trigger-trainer"): 32 | # parameters 33 | super().__init__(args, trainer_name=trainer_name) 34 | 35 | self._init_model(self.args) 36 | self._init_logger(self.args, self.model) 37 | self._init_pl_trainer(self.args, self.model, self.logger) 38 | 39 | @print_done(desc="Creating directories and fix random seeds") 40 | def _init_args(self, args): 41 | self.output_dir.mkdir(parents=True, exist_ok=True) 42 | self.experiment_output_dir.mkdir(parents=True, exist_ok=True) 43 | self.log_dir.mkdir(parents=True, exist_ok=True) 44 | self.save_dir.mkdir(parents=True, exist_ok=True) 45 | self.cache_dir.mkdir(parents=True, exist_ok=True) 46 | 47 | pl.seed_everything(args.seed, workers=True) # reproducibility 48 | 49 | @print_done(desc="initialize model") 50 | def _init_model(self, args): 51 | # automatically download from huggingface project 52 | print(f"model_path: {args.model_name_or_path}") 53 | # ============= gpt =============== 54 | if are_same_strings(args.model_name, "gpt"): 55 | self.model: gpt_code = gpt_code(args) 56 | # ============= llama =============== 57 | elif are_same_strings(args.model_name, "llama"): 58 | self.model: llama_finetune = llama_finetune(args) 59 | else: 60 | raise NotImplementedError(f"args.model_name: {args.model_name}") 61 | 62 | @print_done("set up pytorch lightning trainer") 63 | def _init_pl_trainer(self, args, model, logger): 64 | extra_callbacks = [] 65 | # self.checkpoint_callback = Seq2SeqCheckpointCallback( 66 | # output_dir=self.save_dir, 67 | # experiment_name=self.experiment_name, 68 | # monitor="val_loss", 69 | # save_top_k=args.save_top_k, 70 | # every_n_train_steps=args.every_n_train_steps, 71 | # verbose=args.ckpt_verbose, 72 | # ) 73 | 74 | # initialize pl_trainer 75 | # if args.gpus is not None and args.gpus > 1: 76 | # self.train_params["distributed_backend"] = "ddp" 77 | 78 | 
self.checkpoint_callback = SaveCheckpointEveryEpoch( 79 | output_dir=self.save_dir, 80 | experiment_name=self.experiment_name, 81 | last_k=args.save_top_k, 82 | save_weights_only=False, 83 | every_n_epochs=args.save_every_n_epochs, 84 | every_n_train_steps=args.save_every_n_steps 85 | ) 86 | 87 | self.pl_trainer: pl.Trainer = pl.Trainer.from_argparse_args( 88 | args, 89 | enable_model_summary=False, 90 | callbacks=[self.checkpoint_callback, Seq2SeqLoggingCallback(), pl.callbacks.ModelSummary(max_depth=1)] 91 | + extra_callbacks, 92 | logger=logger, 93 | **self.train_params, 94 | ) 95 | 96 | def train(self): 97 | self.auto_find_lr_rate() 98 | self.auto_find_batch_size() 99 | 100 | self.pl_trainer.logger.log_hyperparams(self.args) 101 | 102 | if self.checkpoints: 103 | # training 104 | best_ckpt = self.checkpoints[-1] 105 | self.pl_trainer.fit(self.model, ckpt_path=best_ckpt) 106 | else: 107 | self.pl_trainer.fit(self.model) 108 | 109 | 110 | if __name__ == '__main__': 111 | hparams = parse_args_for_config() 112 | trainer = EventTriggerTrainer(hparams) 113 | trainer.train() 114 | -------------------------------------------------------------------------------- /toy_hanfei.py: -------------------------------------------------------------------------------- 1 | from transformers import pipeline, AutoModelForCausalLM, AutoConfig, AutoTokenizer 2 | 3 | 4 | print("toy starting") 5 | 6 | model_path = "/HanFei/output/task/bloomz-7b1-mt-gpu8-1e5/models/global_step26308" 7 | tokenizer = AutoTokenizer.from_pretrained(model_path) 8 | model = AutoModelForCausalLM.from_pretrained(model_path, ignore_mismatched_sizes=True) 9 | 10 | input_text = '中华人民共和国刑法第一条是:'  # "Article 1 of the Criminal Law of the People's Republic of China is:" 11 | print(input_text) 12 | 13 | generator = pipeline('text-generation', model=model, tokenizer=tokenizer) 14 | print(generator(input_text, max_length=1024, min_length=256, num_return_sequences=1)) 15 | --------------------------------------------------------------------------------
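Note: a minimal evaluation sketch, not part of the repository, showing how the helpers in src/utils/nlg_eval_utils.py above could score generated text against references; the example strings are placeholders.

from src.utils import nlg_eval_utils

# Placeholder outputs; in practice these would come from a generation pipeline such as toy_hanfei.py.
gen_lines = ["the model generated this answer"]
ref_lines = ["the reference answer written by an annotator"]

# corpus BLEU expects token lists, so split the strings first.
metrics = nlg_eval_utils.calculate_bleu(
    ref_lines=[line.split() for line in ref_lines],
    gen_lines=[line.split() for line in gen_lines],
)
# ROUGE and the repetition/distinct statistics work on the raw strings.
metrics.update(nlg_eval_utils.calculate_rouge(pred_lines=gen_lines, tgt_lines=ref_lines))
metrics.update(nlg_eval_utils.repetition_distinct_metric(gen_lines))
print(metrics)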