├── .gitignore
├── LICENSE
├── README.md
├── conference_material
│   ├── poster.pdf
│   └── presentation.pdf
├── data
│   ├── test_with_annotations.csv
│   ├── train.csv
│   └── val.csv
├── docs
│   ├── Annotation Guidelines.txt
│   ├── Comparison_with_MLEC-QA.jpg
│   ├── example.png
│   └── overall_comparison.jpg
└── src
    ├── LoRA
    │   ├── finetune.py
    │   ├── generate.py
    │   ├── infer.py
    │   ├── scripts
    │   │   ├── finetune.sh
    │   │   ├── infer_ori.sh
    │   │   └── infer_sft.sh
    │   ├── templates
    │   │   ├── README.md
    │   │   └── med_template.json
    │   └── utils
    │       ├── README.md
    │       ├── __init__.py
    │       ├── data_format_transform.py
    │       └── prompter.py
    ├── evaluation
    │   ├── evaluate
    │   │   ├── bleu.py
    │   │   ├── metrics4rec.py
    │   │   ├── rouge.py
    │   │   └── utils.py
    │   ├── evaluate_chatglm_result.py
    │   ├── evaluate_ft_result.py
    │   ├── evaluate_gpt_result.py
    │   └── evaluate_lora_result.py
    ├── preprocess
    │   ├── data stats.ipynb
    │   ├── dataset_dist.pdf
    │   ├── generate_prompt.py
    │   └── prompt_templates.py
    ├── ptuning
    │   ├── arguments.py
    │   ├── deepspeed.json
    │   ├── main.py
    │   ├── prediction.sh
    │   ├── train.sh
    │   ├── trainer.py
    │   └── trainer_seq2seq.py
    └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [Junling Liu] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Note: If you are looking for a multimodal dataset, check out our new dataset, **ChiMed-VL-Instruction**, with 469,441 vision-language QA pairs: [https://paperswithcode.com/dataset/qilin-med-vl](https://paperswithcode.com/dataset/qilin-med-vl)
2 |
3 | This paper was presented at NeurIPS 2023 in New Orleans, Louisiana. See the [poster](conference_material/poster.pdf) and [slides](conference_material/presentation.pdf).
4 |
5 | # Benchmarking Large Language Models on CMExam - A Comprehensive Chinese Medical Exam Dataset
6 |
7 | ## Introduction
8 |
9 | CMExam is a dataset sourced from the Chinese National Medical Licensing Examination. It consists of 60K+ multiple-choice questions, each with five additional question-wise annotations: disease group, clinical department, medical discipline, area of competency, and question difficulty level. Alongside the dataset, we conducted comprehensive benchmarks of representative LLMs on CMExam.
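Each CMExam item combines a multiple-choice question, its answer and explanation (the A and E rows of the statistics table), and the five annotations listed above. The sketch below shows one way to hold such an item in Python and map it onto the `instruction`/`input`/`output` keys that `src/LoRA/finetune.py` reads from each `data_point`. It is purely illustrative: the `CMExamItem` fields, the `to_instruction_example` helper, and the instruction wording are hypothetical, not the actual column names in `data/*.csv` or the prompts produced by `preprocess/generate_prompt.py`.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class CMExamItem:
    """Hypothetical record for one CMExam question; field names are illustrative."""
    question: str
    options: Dict[str, str]  # option letter -> option text
    answer: str              # option letter(s), e.g. "C"
    explanation: str
    # The five question-wise annotations described in the Introduction.
    disease_group: str
    clinical_department: str
    medical_discipline: str
    competency: str
    difficulty: int


def to_instruction_example(item: CMExamItem) -> Dict[str, str]:
    """Map an item onto the instruction/input/output keys used by finetune.py.

    The instruction text is a placeholder, not the repository's actual prompt.
    """
    # Render options in letter order, one per line.
    options = "\n".join(f"{k}. {v}" for k, v in sorted(item.options.items()))
    return {
        "instruction": "Answer the multiple-choice question and explain why.",
        "input": f"{item.question}\n{options}",
        "output": f"{item.answer}\n{item.explanation}",
    }
```

Keeping the annotations on the record, even though the training example ignores them, makes it easy to slice results by disease group or difficulty afterwards.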
10 |
11 |
12 |
13 | ## Dataset Statistics
14 | | Statistic | Train | Val | Test | Total |
15 | |----------------------------|---------------|---------------|---------------|---------------|
16 | | Question | 54,497 | 6,811 | 6,811 | 68,119 |
17 | | Vocab | 4,545 | 3,620 | 3,599 | 4,629 |
18 | | Max Q tokens | 676 | 500 | 585 | 676 |
19 | | Max A tokens | 5 | 5 | 5 | 5 |
20 | | Max E tokens | 2,999 | 2,678 | 2,680 | 2,999 |
21 | | Avg Q tokens | 29.78 | 30.07 | 32.63 | 30.83 |
22 | | Avg A tokens | 1.08 | 1.07 | 1.07 | 1.07 |
23 | | Avg E tokens | 186.24 | 188.95 | 201.44 | 192.21 |
24 | | Median (Q1, Q3) Q tokens | 17 (12, 32) | 18 (12, 32) | 18 (12, 37) | 18 (12, 32) |
25 | | Median (Q1, Q3) A tokens | 1 (1, 1) | 1 (1, 1) | 1 (1, 1) | 1 (1, 1) |
26 | | Median (Q1, Q3) E tokens | 146 (69, 246) | 143 (65, 247) | 158 (80, 263) | 146 (69, 247) |
27 |
28 | \*Q: Question; A: Answer; E: Explanation; Vocab: vocabulary size
29 |
30 | ## Annotation Characteristics
31 | | Annotation Content | References | Unique values |
32 | |----------------------------|-----------------------------|---------------|
33 | | Disease Groups | The 11th revision of the International Classification of Diseases (ICD-11) | 27 |
34 | | Clinical Departments | The Directory of Medical Institution Diagnostic and Therapeutic Categories (DMIDTC) | 36 |
35 | | Medical Disciplines | List of Graduate Education Disciplinary Majors (2022) | 7 |
36 | | Medical Competencies | Medical Professionals | 4 |
37 | | Difficulty Level | Human Performance | 5 |
38 |
39 | ## Benchmarks
40 |
41 | Alongside the dataset, we conducted thorough experiments with representative LLMs and QA algorithms on CMExam.
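The split sizes in the Dataset Statistics table can be sanity-checked with simple arithmetic, and the per-split rows (max, average, and median with first and third quartiles) are standard summary statistics over per-item token counts. The sketch below recomputes both. How the text was tokenized is not stated here, so `summarize` assumes you already have a list of token counts. Its quartile convention (`statistics.quantiles` with the default `"exclusive"` method) is also an assumption and may differ slightly from the tooling used to build the table.

```python
import statistics
from typing import Dict, List

# Question counts per split, copied from the Dataset Statistics table.
SPLITS = {"train": 54_497, "val": 6_811, "test": 6_811}
TOTAL = 68_119


def split_fractions(splits: Dict[str, int]) -> Dict[str, float]:
    """Return each split's share of all questions."""
    total = sum(splits.values())
    return {name: n / total for name, n in splits.items()}


def summarize(token_counts: List[int]) -> Dict[str, float]:
    """Max, mean, median, and quartiles of per-item token counts.

    Assumes token_counts came from a tokenizer of your choice; the quartile
    method (statistics.quantiles, method="exclusive") is an assumption.
    """
    q1, median, q3 = statistics.quantiles(token_counts, n=4)
    return {
        "max": max(token_counts),
        "avg": round(statistics.mean(token_counts), 2),
        "median": median,
        "q1": q1,
        "q3": q3,
    }


# The three splits add up to the reported total: roughly an 80/10/10 split.
assert sum(SPLITS.values()) == TOTAL
print({k: round(v, 3) for k, v in split_fractions(SPLITS).items()})
# → {'train': 0.8, 'val': 0.1, 'test': 0.1}
```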
42 |
43 |
44 |
45 | ## Deployment
46 |
47 | To deploy this project, run the following steps from the repository root. Each `cd` is relative to the directory of the previous step, so run them in order.
48 |
49 | ### Environment Setup
50 | ```bash
51 | cd src
52 | pip install -r requirements.txt
53 | ```
54 | ### Data Preprocessing
55 | ```bash
56 | cd preprocess
57 | python generate_prompt.py
58 | ```
59 |
60 | ### P-Tuning
61 | ```bash
62 | cd ../ptuning
63 | bash train.sh
64 | bash prediction.sh
65 | ```
66 |
67 | ### LoRA
68 | ```bash
69 | cd ../LoRA
70 | bash ./scripts/finetune.sh
71 | bash ./scripts/infer_ori.sh
72 | bash ./scripts/infer_sft.sh
73 | ```
74 |
75 | ### Evaluation
76 | ```bash
77 | cd ../evaluation
78 | python evaluate_lora_result.py --csv_file_path path/to/csv/file
79 | ```
80 |
81 | ## Side notes
82 | ### Limitations:
83 | - Excluding non-textual questions may introduce biases.
84 | - BLEU and ROUGE metrics cannot fully assess the quality of explanations; more thorough expert evaluation is needed in future work.
85 | ### Ethics in Data Collection:
86 | - Adheres to legal and ethical guidelines.
87 | - Authenticated and accurate for evaluating LLMs.
88 | - Intended for academic/research use only; commercial misuse prohibited.
89 | - Users should acknowledge the dataset's limitations and specific context.
90 | - Not for assessing individual medical competence or patient diagnosis.
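The limitation noted above, that BLEU and ROUGE cannot fully assess explanations, is easy to see with ROUGE-L. ROUGE-L scores a candidate by its longest common subsequence (LCS) with the reference. The sketch below is an independent, minimal ROUGE-L F1 (beta = 1) over pre-tokenized text. For Chinese, character tokens are a common choice. It is not the code in `src/evaluation/evaluate/rouge.py`, whose details may differ. In the example, a candidate that drops a single negation, and so reverses the clinical meaning, still scores about 0.89.

```python
from typing import List, Tuple


def lcs_length(a: List[str], b: List[str]) -> int:
    """Length of the longest common subsequence, via dynamic programming."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, start=1):
        for j, y in enumerate(b, start=1):
            if x == y:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return dp[len(a)][len(b)]


def rouge_l(candidate: List[str], reference: List[str]) -> Tuple[float, float, float]:
    """Return (precision, recall, F1) for ROUGE-L with beta = 1."""
    lcs = lcs_length(candidate, reference)
    if lcs == 0:
        return 0.0, 0.0, 0.0
    precision = lcs / len(candidate)
    recall = lcs / len(reference)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


reference = "drug A is not indicated".split()
candidate = "drug A is indicated".split()  # opposite clinical meaning
p, r, f = rouge_l(candidate, reference)
print(round(p, 3), round(r, 3), round(f, 3))  # → 1.0 0.8 0.889
```

Because the score only rewards overlapping tokens in order, it cannot tell a correct explanation from one that contradicts the reference, which is why expert review remains necessary.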
91 | ### Future directions:
92 | - Translate to English (in progress)
93 | - Include multimodal information (see our new dataset, ChiMed-Vision-Language-Instruction, with 469,441 QA pairs: [https://paperswithcode.com/dataset/qilin-med-vl](https://paperswithcode.com/dataset/qilin-med-vl))
94 |
95 | ## Citation
96 | Benchmarking Large Language Models on CMExam -- A Comprehensive Chinese Medical Exam Dataset
97 | https://arxiv.org/abs/2306.03030
98 |
99 | ```
100 | @article{liu2023benchmarking,
101 |   title={Benchmarking Large Language Models on CMExam--A Comprehensive Chinese Medical Exam Dataset},
102 |   author={Liu, Junling and Zhou, Peilin and Hua, Yining and Chong, Dading and Tian, Zhongyu and Liu, Andrew and Wang, Helin and You, Chenyu and Guo, Zhenhua and Zhu, Lei and others},
103 |   journal={arXiv preprint arXiv:2306.03030},
104 |   year={2023}
105 | }
106 | ```
107 |
--------------------------------------------------------------------------------
/conference_material/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/conference_material/poster.pdf
--------------------------------------------------------------------------------
/conference_material/presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/conference_material/presentation.pdf
--------------------------------------------------------------------------------
/docs/Annotation Guidelines.txt:
--------------------------------------------------------------------------------
1 | 1. Comprehensive Question Understanding: Before annotating, read the medical question carefully and make sure you understand its full context and significance.
2 | 2. Subject Categorization: Identify the specific subject or medical field the question pertains to, such as cardiology, pediatrics, or pathology.
3 | 3. Principal Symptoms or Medical Conditions: Identify the primary symptoms or medical conditions described in the question.
4 | 4. Examination of Pertinent Factors: Check the question for associated factors, such as the severity of the condition, its etiology, and any patient history it provides.
5 | 5. Appropriate Classification System Usage: Annotate with the classification system that fits the identified subject and symptoms. Suitable systems include the 11th revision of the International Classification of Diseases (ICD-11), the Directory of Medical Institution Diagnostic and Therapeutic Categories (DMIDTC), and others.
6 | 6. Addressing Multiple Annotations: When a question involves multiple symptoms or medical conditions, choose the most relevant classification.
7 | 7. Ensuring High-Quality Annotations: Follow the guidelines and definitions of the chosen classification system. This reduces subjectivity and ambiguity and keeps the annotations precise.
8 | 8. Navigating Queries and Uncertainties: If doubts arise during annotation, consult the official documents and glossaries of the chosen classification system. Discussing the question with professionals is also advised.
9 | 9. Resolving Discrepancies: When annotators disagree, they discuss the item together until they reach consensus on a single annotation.
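Step 9 can be supported by a small script that lists the items on which two annotators disagree, so that only those items go to the collaborative discussion. This is a minimal sketch. The item IDs, label strings, and the `find_discrepancies` helper are hypothetical and are not part of the repository.

```python
from typing import Dict, List, Tuple


def find_discrepancies(
    annotator_a: Dict[str, str], annotator_b: Dict[str, str]
) -> List[Tuple[str, str, str]]:
    """Return (item_id, label_a, label_b) for every item labeled differently.

    Items annotated by only one person are reported with "<missing>" so that
    they are not silently skipped.
    """
    conflicts = []
    # Union of item IDs from both annotators, in a stable order.
    for item_id in sorted(annotator_a.keys() | annotator_b.keys()):
        label_a = annotator_a.get(item_id, "<missing>")
        label_b = annotator_b.get(item_id, "<missing>")
        if label_a != label_b:
            conflicts.append((item_id, label_a, label_b))
    return conflicts


# Hypothetical clinical-department labels from two annotators.
a = {"q1": "Cardiology", "q2": "Pediatrics", "q3": "Pathology"}
b = {"q1": "Cardiology", "q2": "Neonatology", "q4": "Pathology"}
for conflict in find_discrepancies(a, b):
    print(conflict)
```

The resulting list is the agenda for the discussion step. Items that both annotators labeled identically need no further review.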
-------------------------------------------------------------------------------- /docs/Comparison_with_MLEC-QA.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/docs/Comparison_with_MLEC-QA.jpg -------------------------------------------------------------------------------- /docs/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/docs/example.png -------------------------------------------------------------------------------- /docs/overall_comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/docs/overall_comparison.jpg -------------------------------------------------------------------------------- /src/LoRA/finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List 4 | 5 | import fire 6 | import torch 7 | import transformers 8 | from datasets import load_dataset 9 | 10 | """ 11 | Unused imports: 12 | import torch.nn as nn 13 | import bitsandbytes as bnb 14 | """ 15 | 16 | from peft import ( 17 | LoraConfig, 18 | get_peft_model, 19 | get_peft_model_state_dict, 20 | prepare_model_for_int8_training, 21 | set_peft_model_state_dict, 22 | ) 23 | from transformers import LlamaForCausalLM, LlamaTokenizer 24 | 25 | from utils.prompter import Prompter 26 | from utils.data_format_transform import filter_and_convert 27 | 28 | 29 | def train( 30 | # model/data params 31 | base_model: str = "medalpaca/medalpaca-7b", # the only required argument 32 | data_path: str = "../../data/train_prompt.json", 33 | valid_data_path: str = "../../data/val_prompt.json", 34 | output_dir: str = 
"./lora-medalpaca", 35 | prompt_id: str = '1', 36 | # training hyperparams 37 | batch_size: int = 128, 38 | micro_batch_size: int = 8, 39 | num_epochs: int = 2, 40 | learning_rate: float = 3e-4, 41 | cutoff_len: int = 256, 42 | val_set_size: int = 500, 43 | sample: int = None, 44 | # lora hyperparams 45 | lora_r: int = 8, 46 | lora_alpha: int = 16, 47 | lora_dropout: float = 0.05, 48 | lora_target_modules: List[str] = [ 49 | "q_proj", 50 | "v_proj", 51 | ], 52 | # llm hyperparams 53 | train_on_inputs: bool = False, # if False, masks out inputs in loss 54 | group_by_length: bool = False, # faster, but produces an odd training loss curve 55 | # Others 56 | logging_steps: int = 8, 57 | eval_steps: int = 100, 58 | save_steps: int = 100, 59 | save_total_limit: int = 1000, 60 | # wandb params 61 | wandb_project: str = "llama_med", 62 | wandb_run_name: str = "", 63 | wandb_watch: str = "", # options: false | gradients | all 64 | wandb_log_model: str = "", # options: false | true 65 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 66 | prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca. 
67 | ): 68 | if int(os.environ.get("LOCAL_RANK", 0)) == 0: 69 | print( 70 | f"Training model with params:\n" 71 | f"base_model: {base_model}\n" 72 | f"data_path: {data_path}\n" 73 | f"output_dir: {output_dir}\n" 74 | f"batch_size: {batch_size}\n" 75 | f"micro_batch_size: {micro_batch_size}\n" 76 | f"num_epochs: {num_epochs}\n" 77 | f"learning_rate: {learning_rate}\n" 78 | f"cutoff_len: {cutoff_len}\n" 79 | f"val_set_size: {val_set_size}\n" 80 | f"lora_r: {lora_r}\n" 81 | f"lora_alpha: {lora_alpha}\n" 82 | f"lora_dropout: {lora_dropout}\n" 83 | f"lora_target_modules: {lora_target_modules}\n" 84 | f"train_on_inputs: {train_on_inputs}\n" 85 | f"group_by_length: {group_by_length}\n" 86 | f"wandb_project: {wandb_project}\n" 87 | f"wandb_run_name: {wandb_run_name}\n" 88 | f"wandb_watch: {wandb_watch}\n" 89 | f"wandb_log_model: {wandb_log_model}\n" 90 | f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" 91 | f"prompt template: {prompt_template_name}\n" 92 | ) 93 | assert ( 94 | base_model 95 | ), "Please specify a --base_model, e.g. 
--base_model='decapoda-research/llama-7b-hf'" 96 | gradient_accumulation_steps = batch_size // micro_batch_size 97 | 98 | prompter = Prompter(prompt_template_name) 99 | 100 | device_map = "auto" 101 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 102 | ddp = world_size != 1 103 | if ddp: 104 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 105 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 106 | 107 | # Check if parameter passed or if set within environ 108 | use_wandb = len(wandb_project) > 0 or ( 109 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 110 | ) 111 | # Only overwrite environ if wandb param passed 112 | if len(wandb_project) > 0: 113 | os.environ["WANDB_PROJECT"] = wandb_project 114 | if len(wandb_watch) > 0: 115 | os.environ["WANDB_WATCH"] = wandb_watch 116 | if len(wandb_log_model) > 0: 117 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model 118 | 119 | model = LlamaForCausalLM.from_pretrained( 120 | base_model, 121 | load_in_8bit=True, 122 | torch_dtype=torch.float16, 123 | device_map=device_map, 124 | ) 125 | 126 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 127 | 128 | tokenizer.pad_token_id = ( 129 | 0 # unk. 
we want this to be different from the eos token 130 | ) 131 | tokenizer.padding_side = "left" # Allow batched inference 132 | 133 | def tokenize(prompt, add_eos_token=True): 134 | # there's probably a way to do this with the tokenizer settings 135 | # but again, gotta move fast 136 | result = tokenizer( 137 | prompt, 138 | truncation=True, 139 | max_length=cutoff_len, 140 | padding=False, 141 | return_tensors=None, 142 | ) 143 | if ( 144 | result["input_ids"][-1] != tokenizer.eos_token_id 145 | and len(result["input_ids"]) < cutoff_len 146 | and add_eos_token 147 | ): 148 | result["input_ids"].append(tokenizer.eos_token_id) 149 | result["attention_mask"].append(1) 150 | 151 | result["labels"] = result["input_ids"].copy() 152 | 153 | return result 154 | 155 | def generate_and_tokenize_prompt(data_point): 156 | full_prompt = prompter.generate_prompt( 157 | data_point["instruction"], 158 | data_point["input"], 159 | data_point["output"], 160 | ) 161 | tokenized_full_prompt = tokenize(full_prompt) 162 | if not train_on_inputs: 163 | user_prompt = prompter.generate_prompt( 164 | data_point["instruction"], data_point["input"] 165 | ) 166 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 167 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 168 | 169 | tokenized_full_prompt["labels"] = [ 170 | -100 171 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 172 | user_prompt_len: 173 | ] # could be sped up, probably 174 | return tokenized_full_prompt 175 | 176 | model = prepare_model_for_int8_training(model) 177 | 178 | config = LoraConfig( 179 | r=lora_r, 180 | lora_alpha=lora_alpha, 181 | target_modules=lora_target_modules, 182 | lora_dropout=lora_dropout, 183 | bias="none", 184 | task_type="CAUSAL_LM", 185 | ) 186 | model = get_peft_model(model, config) 187 | 188 | filtered_train_data_path = filter_and_convert(data_path, prompt_id, sample) 189 | if filtered_train_data_path.endswith(".json") or filtered_train_data_path.endswith(".jsonl"): 
190 | data = load_dataset("json", data_files=filtered_train_data_path) 191 | else: 192 | data = load_dataset(filtered_train_data_path) 193 | filtered_val_data_path = filter_and_convert(valid_data_path, prompt_id) 194 | 195 | if os.path.exists(filtered_val_data_path): 196 | if filtered_val_data_path.endswith(".json") or filtered_val_data_path.endswith(".jsonl"): 197 | valid_data = load_dataset("json", data_files=filtered_val_data_path) 198 | else: 199 | valid_data = load_dataset(filtered_val_data_path) # load the validation file, not the training data_path 200 | else: 201 | valid_data = None 202 | 203 | if resume_from_checkpoint: 204 | # Check the available weights and load them 205 | checkpoint_name = os.path.join( 206 | resume_from_checkpoint, "pytorch_model.bin" 207 | ) # Full checkpoint 208 | if not os.path.exists(checkpoint_name): 209 | checkpoint_name = os.path.join( 210 | resume_from_checkpoint, "adapter_model.bin" 211 | ) # only LoRA model - LoRA config above has to fit 212 | resume_from_checkpoint = ( 213 | False # So the trainer won't try loading its state 214 | ) 215 | # The two files above have a different name depending on how they were saved, but are actually the same. 216 | if os.path.exists(checkpoint_name): 217 | print(f"Restarting from {checkpoint_name}") 218 | adapters_weights = torch.load(checkpoint_name) 219 | set_peft_model_state_dict(model, adapters_weights) # loads in place; recent peft versions return a load result, not the model 220 | else: 221 | print(f"Checkpoint {checkpoint_name} not found") 222 | 223 | model.print_trainable_parameters() # Be more transparent about the % of trainable params.
224 | 225 | if val_set_size > 0 and not valid_data: 226 | train_val = data["train"].train_test_split( 227 | test_size=val_set_size, shuffle=True, seed=2023 228 | ) 229 | train_data = ( 230 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 231 | ) 232 | val_data = ( 233 | train_val["test"].shuffle().map(generate_and_tokenize_prompt) 234 | ) 235 | elif val_set_size > 0 and valid_data: 236 | train_data = ( 237 | data["train"].shuffle(seed=2023).map(generate_and_tokenize_prompt) 238 | ) 239 | val_sample = valid_data["train"].train_test_split( 240 | test_size=val_set_size, shuffle=True, seed=2023 241 | ) 242 | val_data = ( 243 | val_sample["test"].shuffle().map(generate_and_tokenize_prompt) 244 | ) 245 | else: 246 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt) 247 | val_data = None 248 | 249 | if not ddp and torch.cuda.device_count() > 1: 250 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 251 | model.is_parallelizable = True 252 | model.model_parallel = True 253 | 254 | trainer = transformers.Trainer( 255 | model=model, 256 | train_dataset=train_data, 257 | eval_dataset=val_data, 258 | args=transformers.TrainingArguments( 259 | per_device_train_batch_size=micro_batch_size, 260 | gradient_accumulation_steps=gradient_accumulation_steps, 261 | warmup_ratio=0.1, 262 | num_train_epochs=num_epochs, 263 | learning_rate=learning_rate, 264 | fp16=True, 265 | logging_steps=logging_steps, 266 | optim="adamw_torch", 267 | evaluation_strategy="steps" if val_set_size > 0 else "no", 268 | save_strategy="steps", 269 | eval_steps=eval_steps if val_set_size > 0 else None, 270 | save_steps=save_steps, 271 | output_dir=output_dir, 272 | save_total_limit=save_total_limit, 273 | load_best_model_at_end=True if val_set_size > 0 else False, 274 | ddp_find_unused_parameters=False if ddp else None, 275 | group_by_length=group_by_length, 276 | report_to="wandb" if use_wandb else None, 277 | run_name=wandb_run_name if 
use_wandb else None, 278 | ), 279 | data_collator=transformers.DataCollatorForSeq2Seq( 280 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 281 | ), 282 | ) 283 | model.config.use_cache = False 284 | 285 | old_state_dict = model.state_dict 286 | model.state_dict = ( 287 | lambda self, *_, **__: get_peft_model_state_dict( 288 | self, old_state_dict() 289 | ) 290 | ).__get__(model, type(model)) 291 | 292 | if torch.__version__ >= "2" and sys.platform != "win32": 293 | model = torch.compile(model) 294 | 295 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 296 | 297 | model.save_pretrained(output_dir) 298 | 299 | print( 300 | "\n If there's a warning about missing keys above, please disregard :)" 301 | ) 302 | 303 | 304 | if __name__ == "__main__": 305 | fire.Fire(train) 306 | -------------------------------------------------------------------------------- /src/LoRA/generate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import fire 4 | import gradio as gr 5 | import torch 6 | import transformers 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 9 | 10 | from utils.prompter import Prompter 11 | 12 | if torch.cuda.is_available(): 13 | device = "cuda" 14 | else: 15 | device = "cpu" 16 | 17 | try: 18 | if torch.backends.mps.is_available(): 19 | device = "mps" 20 | except: # noqa: E722 21 | pass 22 | 23 | 24 | def main( 25 | load_8bit: bool = False, 26 | base_model: str = "decapoda-research/llama-7b-hf", 27 | lora_weights: str = "tloen/alpaca-lora-7b", 28 | prompt_template: str = "med_template", # The prompt template to use; defaults to med_template. 29 | server_name: str = "0.0.0.0", # '0.0.0.0' listens on all network interfaces 30 | share_gradio: bool = True, 31 | ): 32 | assert ( 33 | base_model 34 | ), "Please specify a --base_model, e.g.
--base_model='decapoda-research/llama-7b-hf'" 35 | 36 | prompter = Prompter(prompt_template) 37 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 38 | if device == "cuda": 39 | model = LlamaForCausalLM.from_pretrained( 40 | base_model, 41 | load_in_8bit=load_8bit, 42 | torch_dtype=torch.float16, 43 | device_map="auto", 44 | ) 45 | model = PeftModel.from_pretrained( 46 | model, 47 | lora_weights, 48 | torch_dtype=torch.float16, 49 | ) 50 | elif device == "mps": 51 | model = LlamaForCausalLM.from_pretrained( 52 | base_model, 53 | device_map={"": device}, 54 | torch_dtype=torch.float16, 55 | ) 56 | model = PeftModel.from_pretrained( 57 | model, 58 | lora_weights, 59 | device_map={"": device}, 60 | torch_dtype=torch.float16, 61 | ) 62 | else: 63 | model = LlamaForCausalLM.from_pretrained( 64 | base_model, device_map={"": device}, low_cpu_mem_usage=True 65 | ) 66 | model = PeftModel.from_pretrained( 67 | model, 68 | lora_weights, 69 | device_map={"": device}, 70 | ) 71 | 72 | # unwind broken decapoda-research config 73 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 74 | model.config.bos_token_id = 1 75 | model.config.eos_token_id = 2 76 | 77 | if not load_8bit: 78 | model.half() # seems to fix bugs for some users. 
79 | 80 | model.eval() 81 | if torch.__version__ >= "2" and sys.platform != "win32": 82 | model = torch.compile(model) 83 | 84 | def evaluate( 85 | instruction, 86 | input=None, 87 | temperature=0.1, 88 | top_p=0.75, 89 | top_k=40, 90 | num_beams=4, 91 | max_new_tokens=128, 92 | **kwargs, 93 | ): 94 | prompt = prompter.generate_prompt(instruction, input) 95 | inputs = tokenizer(prompt, return_tensors="pt") 96 | input_ids = inputs["input_ids"].to(device) 97 | generation_config = GenerationConfig( 98 | temperature=temperature, 99 | top_p=top_p, 100 | top_k=top_k, 101 | num_beams=num_beams, 102 | **kwargs, 103 | ) 104 | with torch.no_grad(): 105 | generation_output = model.generate( 106 | input_ids=input_ids, 107 | generation_config=generation_config, 108 | return_dict_in_generate=True, 109 | output_scores=True, 110 | max_new_tokens=max_new_tokens, 111 | ) 112 | s = generation_output.sequences[0] 113 | output = tokenizer.decode(s) 114 | return prompter.get_response(output) 115 | 116 | gr.Interface( 117 | fn=evaluate, 118 | inputs=[ 119 | gr.components.Textbox( 120 | lines=2, 121 | label="Instruction", 122 | placeholder="Tell me about alpacas.", 123 | ), 124 | gr.components.Textbox(lines=2, label="Input", placeholder="none"), 125 | gr.components.Slider( 126 | minimum=0, maximum=1, value=0.1, label="Temperature" 127 | ), 128 | gr.components.Slider( 129 | minimum=0, maximum=1, value=0.75, label="Top p" 130 | ), 131 | gr.components.Slider( 132 | minimum=0, maximum=100, step=1, value=40, label="Top k" 133 | ), 134 | gr.components.Slider( 135 | minimum=1, maximum=4, step=1, value=4, label="Beams" 136 | ), 137 | gr.components.Slider( 138 | minimum=1, maximum=2000, step=1, value=128, label="Max tokens" 139 | ), 140 | ], 141 | outputs=[ 142 | gr.components.Textbox( 143 | lines=5, 144 | label="Output", 145 | ) 146 | ], 147 | title="🦙🌲 Alpaca-LoRA", 148 | description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions.
It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).", # noqa: E501 149 | ).launch(server_name=server_name, share=share_gradio) 150 | # Old testing code follows. 151 | 152 | """ 153 | # testing code for readme 154 | for instruction in [ 155 | "Tell me about alpacas.", 156 | "Tell me about the president of Mexico in 2019.", 157 | "Tell me about the king of France in 2019.", 158 | "List all Canadian provinces in alphabetical order.", 159 | "Write a Python program that prints the first 10 Fibonacci numbers.", 160 | "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501 161 | "Tell me five words that rhyme with 'shock'.", 162 | "Translate the sentence 'I have no mouth but I must scream' into Spanish.", 163 | "Count up from 1 to 500.", 164 | ]: 165 | print("Instruction:", instruction) 166 | print("Response:", evaluate(instruction)) 167 | print() 168 | """ 169 | 170 | 171 | if __name__ == "__main__": 172 | fire.Fire(main) 173 | -------------------------------------------------------------------------------- /src/LoRA/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import fire 5 | import torch 6 | import pandas as pd 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 9 | from utils.prompter import Prompter 10 | from tqdm import tqdm 11 | 12 | device = "cuda" if torch.cuda.is_available() else "cpu" 13 | 14 | class InferenceEngine: 15 | 16 | def __init__(self): 17 | self.device = device 18 | 19 | def load_instruction(self, instruct_dir): 20 | 
input_data = [] 21 | with open(instruct_dir, "r") as f: 22 | lines = f.readlines() 23 | for line in lines: 24 | line = line.strip() 25 | d = json.loads(line) 26 | input_data.append(d) 27 | return input_data 28 | 29 | def load_instruction_from_csv(self, instruct_dir, prompt_idx='all'): 30 | input_data = [] 31 | df = pd.read_csv(instruct_dir, dtype='str') 32 | if prompt_idx!='all': 33 | df = df[df['prompt_idx'] == str(prompt_idx)] 34 | dict_from_df = df.to_dict(orient='index') 35 | for key,value in dict_from_df.items(): 36 | data = {} 37 | data['output'] = value['completion'].strip() 38 | data['instruction'] = value['prompt'].strip() 39 | input_data.append(data) 40 | return input_data, df 41 | 42 | def evaluate(self, 43 | batch, 44 | input=None, 45 | **kwargs, 46 | ): 47 | prompts = [self.prompter.generate_prompt(data["instruction"], input) for data in batch] 48 | inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(device) 49 | generation_config = GenerationConfig( 50 | temperature=self.temperature, 51 | top_p=self.top_p, 52 | top_k=self.top_k, 53 | num_beams=self.num_beams, 54 | **kwargs, 55 | ) 56 | with torch.no_grad(): 57 | generation_output = self.model.generate( 58 | **inputs, 59 | generation_config=generation_config, 60 | # return_dict_in_generate=True, 61 | # output_scores=True, 62 | max_new_tokens=self.max_new_tokens, 63 | num_return_sequences=self.num_return_sequences, 64 | ) 65 | outputs = self.tokenizer.batch_decode(generation_output, skip_special_tokens=True) 66 | return [self.prompter.get_response(output) for output in outputs] 67 | 68 | def infer_from_csv(self, instruct_dir, output_dir, prompt_id): 69 | input_data, df_ori = self.load_instruction_from_csv(instruct_dir, prompt_id) 70 | df_ori.reset_index(drop=True, inplace=True) 71 | col_name = 'model_result' 72 | batched_data = [input_data[i:i+self.batch_size] for i in range(0, len(input_data), self.batch_size)] 73 | model_output_dict = {col_name:[]} 74 | for batch in 
tqdm(batched_data): 75 | instructions = [data["instruction"] for data in batch] 76 | outputs = self.evaluate(batch) 77 | for i, output in enumerate(outputs): 78 | instruction = instructions[i] 79 | golden_output = batch[i]["output"] 80 | print("###inferring###") 81 | print("###instruction###") 82 | print(instruction) 83 | print("###golden output###") 84 | print(golden_output) 85 | print("###model output###") 86 | print(output) 87 | model_output_dict[col_name].append(output) 88 | new_df = pd.DataFrame(model_output_dict) 89 | merged_df = pd.concat([df_ori, new_df], axis=1) 90 | merged_df.to_csv(os.path.join(output_dir, self.output_file_name), index=False) 91 | 92 | def run(self, 93 | load_8bit=False, 94 | base_model="medalpaca/medalpaca-7b", 95 | instruct_dir="../../data/test_prompt.csv", 96 | prompt_id="4", 97 | output_dir="output/", 98 | output_file_name="output.csv", 99 | use_lora=False, 100 | lora_weights="tloen/alpaca-lora-7b", 101 | prompt_template="med_template", 102 | batch_size=4, 103 | temperature=0.1, 104 | top_p=0.75, 105 | top_k=40, 106 | num_beams=4, 107 | max_new_tokens=32, 108 | num_return_sequences=1 109 | ): 110 | self.output_file_name = output_file_name 111 | self.prompter = Prompter(prompt_template) 112 | self.tokenizer = LlamaTokenizer.from_pretrained(base_model, padding_side="left") 113 | self.model = LlamaForCausalLM.from_pretrained( 114 | base_model, 115 | load_in_8bit=load_8bit, 116 | torch_dtype=torch.float16, 117 | device_map="auto", 118 | ) 119 | self.batch_size = batch_size 120 | self.temperature = temperature 121 | self.top_p = top_p 122 | self.top_k = top_k 123 | self.num_beams = num_beams 124 | self.max_new_tokens = max_new_tokens 125 | self.num_return_sequences = num_return_sequences 126 | 127 | if use_lora: 128 | print(f"using lora {lora_weights}") 129 | self.model = PeftModel.from_pretrained( 130 | self.model, 131 | lora_weights, 132 | torch_dtype=torch.float16, 133 | ) 134 | # unwind broken decapoda-research config 135 |
self.model.config.pad_token_id = self.tokenizer.pad_token_id = 0 # unk 136 | self.model.config.bos_token_id = self.tokenizer.bos_token_id 137 | self.model.config.eos_token_id = self.tokenizer.eos_token_id 138 | if not load_8bit: 139 | self.model.half() # seems to fix bugs for some users. 140 | 141 | self.model.eval() 142 | 143 | if torch.__version__ >= "2" and sys.platform != "win32": 144 | self.model = torch.compile(self.model) 145 | 146 | if instruct_dir != "": 147 | filename, file_extension = os.path.splitext(instruct_dir) 148 | file_extension_without_dot = file_extension[1:] 149 | if file_extension_without_dot == 'json': 150 | raise NotImplementedError("JSON input is not supported (this class has no infer_from_json method); pass a .csv file") 151 | elif file_extension_without_dot == 'csv': 152 | self.infer_from_csv(instruct_dir, output_dir, prompt_id) 153 | else: 154 | raise ValueError(f"Unsupported instruct_dir extension: {file_extension!r} (expected .csv)") 155 | else: 156 | for instruction in [ 157 | "我感冒了,怎么治疗", # I have a cold; how should it be treated? 158 | "一个患有肝衰竭综合征的病人,除了常见的临床表现外,还有哪些特殊的体征?", # Besides the common clinical manifestations, what special physical signs does a patient with liver failure syndrome show? 159 | "急性阑尾炎和缺血性心脏病的多发群体有何不同?", # How do the populations most affected by acute appendicitis and ischemic heart disease differ? 160 | "小李最近出现了心动过速的症状,伴有轻度胸痛。体检发现P-R间期延长,伴有T波低平和ST段异常", # Xiao Li recently developed tachycardia with mild chest pain; examination showed a prolonged P-R interval, flattened T waves, and ST-segment abnormalities. 161 | ]: 162 | print("Instruction:", instruction) 163 | print("Response:", self.evaluate([{"instruction": instruction}])[0]) # evaluate() expects a batch of dicts 164 | print() 165 | 166 | if __name__ == "__main__": 167 | fire.Fire(InferenceEngine().run) 168 | -------------------------------------------------------------------------------- /src/LoRA/scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | prompt_template="med_template" 3 | prompt_id="1" 4 | num_epochs=10 5 | # LLaMA-CMExam 6 | exp_tag="LLaMA-CMExam" 7 | python finetune.py \ 8 | --base_model 'decapoda-research/llama-7b-hf' \ 9 | --data_path '../../data/train_prompt.json' \ 10 | --valid_data_path '../../data/val_prompt.json' \ 11 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 12 | --prompt_template_name $prompt_template \ 13 | --micro_batch_size 8 \ 14 | --batch_size 128 \ 15 | --wandb_run_name $exp_tag \ 16 | --prompt_id $prompt_id \ 17 | --num_epochs $num_epochs \ 18 | --cutoff_len
256 \ 19 | --learning_rate 3e-4 \ 20 | --lora_r 8 \ 21 | --lora_alpha 16 22 | # Alpaca-CMExam 23 | exp_tag="Alpaca-CMExam" 24 | python finetune.py \ 25 | --base_model 'decapoda-research/llama-7b-hf' \ 26 | --resume_from_checkpoint 'alpaca-lora-7b' \ 27 | --data_path '../../data/train_prompt.json' \ 28 | --valid_data_path '../../data/val_prompt.json' \ 29 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 30 | --prompt_template_name $prompt_template \ 31 | --micro_batch_size 8 \ 32 | --batch_size 128 \ 33 | --wandb_run_name $exp_tag \ 34 | --prompt_id $prompt_id \ 35 | --num_epochs $num_epochs \ 36 | --cutoff_len 256 \ 37 | --learning_rate 3e-4 \ 38 | --lora_r 16 \ 39 | --lora_alpha 16 \ 40 | --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' 41 | # Huatuo-CMExam 42 | exp_tag="Huatuo-CMExam" 43 | python finetune.py \ 44 | --base_model 'decapoda-research/llama-7b-hf' \ 45 | --resume_from_checkpoint 'lora-alpaca-med' \ 46 | --data_path '../../data/train_prompt.json' \ 47 | --valid_data_path '../../data/val_prompt.json' \ 48 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 49 | --prompt_template_name $prompt_template \ 50 | --micro_batch_size 8 \ 51 | --batch_size 128 \ 52 | --wandb_run_name $exp_tag \ 53 | --prompt_id $prompt_id \ 54 | --num_epochs $num_epochs \ 55 | --cutoff_len 256 \ 56 | --learning_rate 3e-4 \ 57 | --lora_r 8 \ 58 | --lora_alpha 16 59 | # MedAlpaca-CMExam 60 | exp_tag="Medalpaca-CMExam" 61 | python finetune.py \ 62 | --base_model 'medalpaca/medalpaca-7b' \ 63 | --data_path '../../data/train_prompt.json' \ 64 | --valid_data_path '../../data/val_prompt.json' \ 65 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 66 | --prompt_template_name $prompt_template \ 67 | --micro_batch_size 8 \ 68 | --batch_size 128 \ 69 | --wandb_run_name $exp_tag \ 70 | --prompt_id $prompt_id \ 71 | --num_epochs $num_epochs \ 72 | --cutoff_len 256 \ 73 | --learning_rate 3e-4 \ 74 | --lora_r 8 \ 75 | --lora_alpha 16 76 | # 77 | 
prompt_id="4" 78 | num_epochs=1 79 | # LLaMA-CMExam 80 | exp_tag="LLaMA-CMExam" 81 | python finetune.py \ 82 | --base_model 'decapoda-research/llama-7b-hf' \ 83 | --data_path '../../data/train_prompt.json' \ 84 | --valid_data_path '../../data/val_prompt.json' \ 85 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 86 | --prompt_template_name $prompt_template \ 87 | --micro_batch_size 8 \ 88 | --batch_size 128 \ 89 | --wandb_run_name $exp_tag \ 90 | --prompt_id $prompt_id \ 91 | --num_epochs $num_epochs \ 92 | --cutoff_len 256 \ 93 | --learning_rate 3e-4 \ 94 | --lora_r 8 \ 95 | --lora_alpha 16 96 | # Alpaca-CMExam 97 | exp_tag="Alpaca-CMExam" 98 | python finetune.py \ 99 | --base_model 'decapoda-research/llama-7b-hf' \ 100 | --resume_from_checkpoint 'alpaca-lora-7b' \ 101 | --data_path '../../data/train_prompt.json' \ 102 | --valid_data_path '../../data/val_prompt.json' \ 103 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 104 | --prompt_template_name $prompt_template \ 105 | --micro_batch_size 8 \ 106 | --batch_size 128 \ 107 | --wandb_run_name $exp_tag \ 108 | --prompt_id $prompt_id \ 109 | --num_epochs $num_epochs \ 110 | --cutoff_len 256 \ 111 | --learning_rate 3e-4 \ 112 | --lora_r 16 \ 113 | --lora_alpha 16 \ 114 | --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' 115 | # Huatuo-CMExam 116 | exp_tag="Huatuo-CMExam" 117 | python finetune.py \ 118 | --base_model 'decapoda-research/llama-7b-hf' \ 119 | --resume_from_checkpoint 'lora-alpaca-med' \ 120 | --data_path '../../data/train_prompt.json' \ 121 | --valid_data_path '../../data/val_prompt.json' \ 122 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 123 | --prompt_template_name $prompt_template \ 124 | --micro_batch_size 8 \ 125 | --batch_size 128 \ 126 | --wandb_run_name $exp_tag \ 127 | --prompt_id $prompt_id \ 128 | --num_epochs $num_epochs \ 129 | --cutoff_len 256 \ 130 | --learning_rate 3e-4 \ 131 | --lora_r 8 \ 132 | --lora_alpha 16 133 | # 
MedAlpaca-CMExam 134 | exp_tag="Medalpaca-CMExam" 135 | python finetune.py \ 136 | --base_model 'medalpaca/medalpaca-7b' \ 137 | --data_path '../../data/train_prompt.json' \ 138 | --valid_data_path '../../data/val_prompt.json' \ 139 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \ 140 | --prompt_template_name $prompt_template \ 141 | --micro_batch_size 8 \ 142 | --batch_size 128 \ 143 | --wandb_run_name $exp_tag \ 144 | --prompt_id $prompt_id \ 145 | --num_epochs $num_epochs \ 146 | --cutoff_len 256 \ 147 | --learning_rate 3e-4 \ 148 | --lora_r 8 \ 149 | --lora_alpha 16 -------------------------------------------------------------------------------- /src/LoRA/scripts/infer_ori.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # medalpaca prompt 1 3 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 4 | --base_model 'medalpaca/medalpaca-7b' \ 5 | --use_lora False \ 6 | --instruct_dir '../../data/test_prompt.csv' \ 7 | --prompt_template 'med_template' \ 8 | --output_file_name 'medalpaca_1.csv' \ 9 | --prompt_id '1' \ 10 | --batch_size 4 \ 11 | --num_beams 1 \ 12 | --max_new_tokens 64 13 | # medalpaca prompt 4 14 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 15 | --base_model 'medalpaca/medalpaca-7b' \ 16 | --use_lora False \ 17 | --instruct_dir '../../data/test_prompt.csv' \ 18 | --prompt_template 'med_template' \ 19 | --output_file_name 'medalpaca_4.csv' \ 20 | --prompt_id '4' \ 21 | --batch_size 2 \ 22 | --num_beams 4 \ 23 | --max_new_tokens 256 -------------------------------------------------------------------------------- /src/LoRA/scripts/infer_sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # LLaMA-CMExam prompt 1 3 | model_name='LLaMA-CMExam' 4 | prompt_id='1' 5 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 6 | --base_model 'decapoda-research/llama-7b-hf' \ 7 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 8 | --use_lora 
True \ 9 | --instruct_dir '../../data/test_prompt.csv' \ 10 | --prompt_template 'med_template' \ 11 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 12 | --prompt_id ${prompt_id} \ 13 | --batch_size 4 \ 14 | --num_beams 1 \ 15 | --max_new_tokens 32 16 | # LLaMA-CMExam prompt 4 17 | model_name='LLaMA-CMExam' 18 | prompt_id='4' 19 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 20 | --base_model 'decapoda-research/llama-7b-hf' \ 21 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 22 | --use_lora True \ 23 | --instruct_dir '../../data/test_prompt.csv' \ 24 | --prompt_template 'med_template' \ 25 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 26 | --prompt_id ${prompt_id} \ 27 | --batch_size 4 \ 28 | --num_beams 4 \ 29 | --max_new_tokens 256 30 | # Alpaca-CMExam prompt 1 31 | model_name='Alpaca-CMExam' 32 | prompt_id='1' 33 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 34 | --base_model 'decapoda-research/llama-7b-hf' \ 35 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 36 | --use_lora True \ 37 | --instruct_dir '../../data/test_prompt.csv' \ 38 | --prompt_template 'med_template' \ 39 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 40 | --prompt_id ${prompt_id} \ 41 | --batch_size 4 \ 42 | --num_beams 1 \ 43 | --max_new_tokens 32 44 | # Alpaca-CMExam prompt 4 45 | model_name='Alpaca-CMExam' 46 | prompt_id='4' 47 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 48 | --base_model 'decapoda-research/llama-7b-hf' \ 49 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 50 | --use_lora True \ 51 | --instruct_dir '../../data/test_prompt.csv' \ 52 | --prompt_template 'med_template' \ 53 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 54 | --prompt_id ${prompt_id} \ 55 | --batch_size 4 \ 56 | --num_beams 4 \ 57 | --max_new_tokens 256 58 | # Huatuo-CMExam prompt 1 59 | model_name='Huatuo-CMExam' 60 | prompt_id='1' 61 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 62 | --base_model 
'decapoda-research/llama-7b-hf' \ 63 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 64 | --use_lora True \ 65 | --instruct_dir '../../data/test_prompt.csv' \ 66 | --prompt_template 'med_template' \ 67 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 68 | --prompt_id ${prompt_id} \ 69 | --batch_size 4 \ 70 | --num_beams 1 \ 71 | --max_new_tokens 32 72 | # Huatuo-CMExam prompt 4 73 | model_name='Huatuo-CMExam' 74 | prompt_id='4' 75 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 76 | --base_model 'decapoda-research/llama-7b-hf' \ 77 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 78 | --use_lora True \ 79 | --instruct_dir '../../data/test_prompt.csv' \ 80 | --prompt_template 'med_template' \ 81 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 82 | --prompt_id ${prompt_id} \ 83 | --batch_size 4 \ 84 | --num_beams 4 \ 85 | --max_new_tokens 256 86 | # Medalpaca-CMExam prompt 1 87 | model_name='Medalpaca-CMExam' 88 | prompt_id='1' 89 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 90 | --base_model 'medalpaca/medalpaca-7b' \ 91 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 92 | --use_lora True \ 93 | --instruct_dir '../../data/test_prompt.csv' \ 94 | --prompt_template 'med_template' \ 95 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 96 | --prompt_id ${prompt_id} \ 97 | --batch_size 4 \ 98 | --num_beams 1 \ 99 | --max_new_tokens 32 100 | # Medalpaca-CMExam prompt 4 101 | model_name='Medalpaca-CMExam' 102 | prompt_id='4' 103 | CUDA_VISIBLE_DEVICES=0 python infer.py \ 104 | --base_model 'medalpaca/medalpaca-7b' \ 105 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \ 106 | --use_lora True \ 107 | --instruct_dir '../../data/test_prompt.csv' \ 108 | --prompt_template 'med_template' \ 109 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \ 110 | --prompt_id ${prompt_id} \ 111 | --batch_size 4 \ 112 | --num_beams 4 \ 113 | --max_new_tokens 256 
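The inference scripts above run `infer.py`, which splits the selected CSV rows into chunks of `--batch_size` and recovers each answer by splitting the decoded text on the template's `response_split` marker (`### 回答:` in `med_template.json`). The stdlib-only sketch below mirrors those two post-processing steps outside the model code so they can be checked in isolation. `make_batches` and `extract_answer` are hypothetical helper names, not part of the repository, and `extract_answer` is a deliberate variation: it uses `str.partition`, so an output without the marker yields an empty string instead of the `IndexError` that `Prompter.get_response` (`output.split(marker)[1]`) would raise, and it keeps everything after the first marker rather than stopping at a second one.

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")

# Marker copied from src/LoRA/templates/med_template.json ("response_split").
RESPONSE_SPLIT = "### 回答:"


def make_batches(rows: Sequence[T], batch_size: int) -> List[List[T]]:
    """Chunk rows into consecutive batches, like the list comprehension in infer_from_csv."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [list(rows[i:i + batch_size]) for i in range(0, len(rows), batch_size)]


def extract_answer(decoded: str, marker: str = RESPONSE_SPLIT) -> str:
    """Return the stripped text after the first marker, or "" if the marker is absent.

    Hypothetical variation on Prompter.get_response: partition() never raises
    IndexError and keeps any repeated marker that appears later in the output.
    """
    _, found, answer = decoded.partition(marker)
    return answer.strip() if found else ""


if __name__ == "__main__":
    rows: List[Dict[str, str]] = [{"instruction": f"q{i}"} for i in range(5)]
    print([len(b) for b in make_batches(rows, 2)])  # [2, 2, 1]
    print(extract_answer("Q\n### 回答:\n A "))  # A
```

The final batch may be shorter than `batch_size`, which is also why `infer.py` loads the tokenizer with `padding_side="left"`: shorter prompts in a batch are padded on the left so generation continues directly from each prompt's last token.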
-------------------------------------------------------------------------------- /src/LoRA/templates/README.md: -------------------------------------------------------------------------------- 1 | # Prompt templates 2 | 3 | This directory contains template styles for the prompts used to fine-tune LoRA models. 4 | 5 | ## Format 6 | 7 | A template is described via a JSON file with the following keys: 8 | 9 | - `prompt_input`: The template to use when a non-empty input is provided. Uses the `{instruction}` and `{input}` placeholders. 10 | - `prompt_no_input`: The template to use when the input is None or empty. Uses the `{instruction}` placeholder. 11 | - `description`: A short description of the template, with possible use cases. 12 | - `response_split`: The separator used to cut the actual response out of the model output. 13 | 14 | No `{response}` placeholder is used, since the response is always the last element of the template and is simply concatenated to the rest. 15 | 16 | ## Example template 17 | 18 | The default template, used unless otherwise specified, is `alpaca.json`: 19 | 20 | ```json 21 | { 22 | "description": "Template used by Alpaca-LoRA.", 23 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 24 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 25 | "response_split": "### Response:" 26 | } 27 | 28 | ``` 29 | 30 | ## Current templates 31 | 32 | ### alpaca 33 | 34 | Default template used for generic LoRA fine-tunes so far. 35 | 36 | ### alpaca_legacy 37 | 38 | Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
39 | 40 | ### alpaca_short 41 | 42 | A trimmed-down alpaca template that seems to perform just as well while saving some tokens. Models created with the default template seem to be queryable with the short template as well. More experiments are welcome. 43 | 44 | ### vigogne 45 | 46 | The default alpaca template, translated into French. This template was used to train the "Vigogne" LoRA and should be used to query it or for further fine-tuning. 47 | -------------------------------------------------------------------------------- /src/LoRA/templates/med_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by Med Instruction Tuning", 3 | "prompt_input": "{instruction}\n### 回答:\n", 4 | "prompt_no_input": "{instruction}\n### 回答:\n", 5 | "response_split": "### 回答:" 6 | } -------------------------------------------------------------------------------- /src/LoRA/utils/README.md: -------------------------------------------------------------------------------- 1 | # Directory for helper modules 2 | 3 | ## prompter.py 4 | 5 | The `Prompter` class, a template manager that builds prompts from the JSON files in `templates/` and extracts the response from the model output.
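As a rough illustration of what the template manager does, the sketch below reimplements the `generate_prompt` logic from `prompter.py` against the `med_template.json` values, inlined as a dict. Inlining sidesteps `Prompter`'s lookup of `./templates/<name>.json`, which is resolved relative to the current working directory, so the real class expects to be run from `src/LoRA`. `build_prompt` and `input_text` are hypothetical names chosen for this sketch; the real method takes `input`.

```python
from typing import Dict, Optional

# Values copied from src/LoRA/templates/med_template.json.
MED_TEMPLATE: Dict[str, str] = {
    "description": "Template used by Med Instruction Tuning",
    "prompt_input": "{instruction}\n### 回答:\n",
    "prompt_no_input": "{instruction}\n### 回答:\n",
    "response_split": "### 回答:",
}


def build_prompt(
    template: Dict[str, str],
    instruction: str,
    input_text: Optional[str] = None,
    label: Optional[str] = None,
) -> str:
    """Mirror Prompter.generate_prompt: pick a template, fill it, append the label."""
    if input_text:
        # str.format ignores unused keywords, so templates without {input} still work.
        prompt = template["prompt_input"].format(instruction=instruction, input=input_text)
    else:
        prompt = template["prompt_no_input"].format(instruction=instruction)
    # If a label (the expected response) is given, it is appended after the marker.
    if label:
        prompt += label
    return prompt


if __name__ == "__main__":
    print(repr(build_prompt(MED_TEMPLATE, "Q?")))  # 'Q?\n### 回答:\n'
    print(repr(build_prompt(MED_TEMPLATE, "Q?", label="A")))  # 'Q?\n### 回答:\nA'
```

Because `med_template.json` uses the same string for both keys and has no `{input}` placeholder, any input passed to it is silently dropped; only the instruction and the optional label reach the model.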
6 | 7 | `from utils.prompter import Prompter` -------------------------------------------------------------------------------- /src/LoRA/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/src/LoRA/utils/__init__.py -------------------------------------------------------------------------------- /src/LoRA/utils/data_format_transform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/5/17 01:01 3 | # @Author : Peilin Zhou 4 | # @FileName: data_format_transform.py 5 | # @Software: PyCharm 6 | # @E-mail : zhoupl@pku.edu.cn 7 | import json 8 | import os 9 | import argparse 10 | 11 | def filter_and_convert(input_file, target_id, sample=None): 12 | filtered_data = [] 13 | target_id = None if target_id is None else str(target_id) # str(None) would give "None" and defeat the None checks below 14 | 15 | output_file_name = os.path.splitext(input_file)[0] 16 | if target_id is not None: 17 | output_file_name += '_' + str(target_id) 18 | else: 19 | output_file_name += '_all' 20 | 21 | with open(input_file, 'r', encoding='utf-8') as f: 22 | for line in f: 23 | data = json.loads(line) 24 | if target_id is None or target_id=='all' or data['id'] == target_id: 25 | filtered_data.append({ 26 | 'instruction': data['prompt'], 27 | 'input': '', 28 | 'output': data['completion'] 29 | }) 30 | 31 | output_file = output_file_name + '.json' 32 | 33 | with open(output_file, 'w', encoding='utf-8') as f: 34 | if sample: 35 | for data in filtered_data[:sample]: 36 | f.write(json.dumps(data, ensure_ascii=False) + '\n') 37 | else: 38 | for data in filtered_data: 39 | f.write(json.dumps(data, ensure_ascii=False) + '\n') 40 | 41 | print(f"Filtered file is saved to: {output_file}") 42 | return output_file 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='Filter and convert JSON file.') 46 | parser.add_argument('input_file', type=str, nargs='?',
help='path to the input JSON file', default='data/train_prompt.json') 47 | parser.add_argument('target_id', type=str, nargs='?', default=None, help='target ID for filtering (optional)') 48 | args = parser.parse_args() 49 | 50 | input_file_path = args.input_file 51 | target_id = args.target_id 52 | 53 | filter_and_convert(input_file_path, target_id) -------------------------------------------------------------------------------- /src/LoRA/utils/prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated helper to manage templates and prompt building. 3 | """ 4 | 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | 10 | class Prompter(object): 11 | __slots__ = ("template", "_verbose") 12 | 13 | def __init__(self, template_name: str = "", verbose: bool = False): 14 | self._verbose = verbose 15 | if not template_name: 16 | # Enforce the default here, so the constructor can be called with '' and will not break. 17 | template_name = "alpaca" 18 | file_name = osp.join("./templates", f"{template_name}.json") 19 | if not osp.exists(file_name): 20 | raise ValueError(f"Can't read {file_name}") 21 | with open(file_name) as fp: 22 | self.template = json.load(fp) 23 | if self._verbose: 24 | print( 25 | f"Using prompt template {template_name}: {self.template['description']}" 26 | ) 27 | 28 | def generate_prompt( 29 | self, 30 | instruction: str, 31 | input: Union[None, str] = None, 32 | label: Union[None, str] = None, 33 | ) -> str: 34 | # returns the full prompt from instruction and optional input 35 | # if a label (=response, =output) is provided, it's also appended. 
36 | if input: 37 | res = self.template["prompt_input"].format( 38 | instruction=instruction, input=input 39 | ) 40 | else: 41 | res = self.template["prompt_no_input"].format( 42 | instruction=instruction 43 | ) 44 | if label: 45 | res = f"{res}{label}" 46 | if self._verbose: 47 | print(res) 48 | return res 49 | 50 | def get_response(self, output: str) -> str: 51 | return output.split(self.template["response_split"])[1].strip() 52 | -------------------------------------------------------------------------------- /src/evaluation/evaluate/bleu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py 3 | 4 | Python implementation of BLEU and smooth-BLEU. 5 | 6 | This module provides a Python implementation of BLEU and smooth-BLEU. 7 | Smooth BLEU is computed following the method outlined in the paper: 8 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 9 | evaluation metrics for machine translation. COLING 2004. 10 | """ 11 | 12 | import collections 13 | import math 14 | 15 | 16 | def _get_ngrams(segment, max_order): 17 | """Extracts all n-grams up to a given maximum order from an input segment. 18 | 19 | Args: 20 | segment: text segment from which n-grams will be extracted. 21 | max_order: maximum length in tokens of the n-grams returned by this 22 | method. 23 | 24 | Returns: 25 | The Counter containing all n-grams up to max_order in segment 26 | with a count of how many times each n-gram occurred. 27 | """ 28 | ngram_counts = collections.Counter() 29 | for order in range(1, max_order + 1): 30 | for i in range(0, len(segment) - order + 1): 31 | ngram = tuple(segment[i:i+order]) 32 | ngram_counts[ngram] += 1 33 | return ngram_counts 34 | 35 | 36 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 37 | smooth=False): 38 | """Computes BLEU score of translated segments against one or more references.
39 | 40 | Args: 41 | reference_corpus: list of lists of references for each translation. Each 42 | reference should be tokenized into a list of tokens. 43 | translation_corpus: list of translations to score. Each translation 44 | should be tokenized into a list of tokens. 45 | max_order: Maximum n-gram order to use when computing BLEU score. 46 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 47 | 48 | Returns: 49 | 6-tuple of (BLEU score, n-gram precisions, brevity penalty, 50 | length ratio, translation length, reference length). 51 | """ 52 | matches_by_order = [0] * max_order 53 | possible_matches_by_order = [0] * max_order 54 | reference_length = 0 55 | translation_length = 0 56 | for (references, translation) in zip(reference_corpus, 57 | translation_corpus): 58 | reference_length += min(len(r) for r in references) 59 | translation_length += len(translation) 60 | 61 | merged_ref_ngram_counts = collections.Counter() 62 | for reference in references: 63 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 64 | translation_ngram_counts = _get_ngrams(translation, max_order) 65 | overlap = translation_ngram_counts & merged_ref_ngram_counts 66 | for ngram in overlap: 67 | matches_by_order[len(ngram)-1] += overlap[ngram] 68 | for order in range(1, max_order+1): 69 | possible_matches = len(translation) - order + 1 70 | if possible_matches > 0: 71 | possible_matches_by_order[order-1] += possible_matches 72 | 73 | precisions = [0] * max_order 74 | for i in range(0, max_order): 75 | if smooth: 76 | precisions[i] = ((matches_by_order[i] + 1.) / 77 | (possible_matches_by_order[i] + 1.)) 78 | else: 79 | if possible_matches_by_order[i] > 0: 80 | precisions[i] = (float(matches_by_order[i]) / 81 | possible_matches_by_order[i]) 82 | else: 83 | precisions[i] = 0.0 84 | 85 | if min(precisions) > 0: 86 | p_log_sum = sum((1.
/ max_order) * math.log(p) for p in precisions) 87 | geo_mean = math.exp(p_log_sum) 88 | else: 89 | geo_mean = 0 90 | 91 | ratio = float(translation_length) / reference_length 92 | 93 | if ratio > 1.0: 94 | bp = 1. 95 | else: 96 | bp = math.exp(1 - 1. / ratio) 97 | 98 | bleu = geo_mean * bp 99 | 100 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 101 | -------------------------------------------------------------------------------- /src/evaluation/evaluate/metrics4rec.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import math 4 | import numpy as np 5 | import heapq 6 | 7 | 8 | def evaluate_old(predict, groundtruth, topk=10): 9 | """[Deprecated] Compute metrics for predicted recommendations. 10 | Args: 11 | predict: a dict with key = and value = 12 | groundtruth: a dict with key = and value = . 13 | Returns: 14 | Dict of metrics. 15 | """ 16 | invalid_users = [] 17 | 18 | # Compute metrics 19 | precisions, recalls, ndcgs, hits = [], [], [], [] 20 | for uid in groundtruth: 21 | if uid not in predict or len(predict[uid]) < topk: 22 | invalid_users.append(uid) 23 | continue 24 | pred_list, rel_set = predict[uid][:topk], groundtruth[uid] 25 | if len(pred_list) == 0: 26 | continue 27 | 28 | dcg = 0.0 29 | hit_num = 0.0 30 | for i in range(len(pred_list)): 31 | if pred_list[i] in rel_set: 32 | dcg += 1.0 / (math.log(i + 2) / math.log(2)) 33 | hit_num += 1 34 | # idcg 35 | idcg = 0.0 36 | for i in range(min(len(rel_set), len(pred_list))): 37 | idcg += 1.0 / (math.log(i + 2) / math.log(2)) 38 | ndcg = dcg / idcg 39 | recall = hit_num / len(rel_set) 40 | precision = hit_num / len(pred_list) 41 | hit = 1.0 if hit_num > 0.0 else 0.0 42 | 43 | ndcgs.append(ndcg) 44 | recalls.append(recall) 45 | precisions.append(precision) 46 | hits.append(hit) 47 | 48 | avg_precision = np.mean(precisions) 49 | avg_recall = np.mean(recalls) 50 | 
avg_ndcg = np.mean(ndcgs) 51 | avg_hit = np.mean(hits) 52 | msg = "NDCG={:.4f} | Recall={:.4f} | HR={:.4f} | Precision={:.4f} | Invalid users={}".format( 53 | avg_ndcg, avg_recall, avg_hit, avg_precision, len(invalid_users) 54 | ) 55 | print(msg) 56 | return msg 57 | 58 | 59 | def recall_at_k(r, k, all_pos_num): 60 | r = np.asarray(r)[:k] 61 | return np.sum(r) / all_pos_num 62 | 63 | 64 | def hit_at_k(r, k): 65 | r = np.asarray(r)[:k] 66 | if np.sum(r) > 0: 67 | return 1.0 68 | else: 69 | return 0.0 70 | 71 | 72 | def mean_reciprocal_rank(rs): 73 | """Score is reciprocal of the rank of the first relevant item 74 | First element is 'rank 1'. Relevance is binary (nonzero is relevant). 75 | Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank 76 | >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] 77 | >>> mean_reciprocal_rank(rs) 78 | 0.61111111111111105 79 | >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) 80 | >>> mean_reciprocal_rank(rs) 81 | 0.5 82 | >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]] 83 | >>> mean_reciprocal_rank(rs) 84 | 0.75 85 | Args: 86 | rs: Iterator of relevance scores (list or numpy) in rank order 87 | (first element is the first item) 88 | Returns: 89 | Mean reciprocal rank 90 | """ 91 | rs = (np.asarray(r).nonzero()[0] for r in rs) 92 | return np.mean([1.0 / (r[0] + 1) if r.size else 0.0 for r in rs]) 93 | 94 | 95 | def r_precision(r): 96 | """Score is precision after all relevant documents have been retrieved 97 | Relevance is binary (nonzero is relevant). 
98 | >>> r = [0, 0, 1] 99 | >>> r_precision(r) 100 | 0.33333333333333331 101 | >>> r = [0, 1, 0] 102 | >>> r_precision(r) 103 | 0.5 104 | >>> r = [1, 0, 0] 105 | >>> r_precision(r) 106 | 1.0 107 | Args: 108 | r: Relevance scores (list or numpy) in rank order 109 | (first element is the first item) 110 | Returns: 111 | R Precision 112 | """ 113 | r = np.asarray(r) != 0 114 | z = r.nonzero()[0] 115 | if not z.size: 116 | return 0.0 117 | return np.mean(r[: z[-1] + 1]) 118 | 119 | 120 | def precision_at_k(r, k): 121 | """Score is precision @ k 122 | Relevance is binary (nonzero is relevant). 123 | >>> r = [0, 0, 1] 124 | >>> precision_at_k(r, 1) 125 | 0.0 126 | >>> precision_at_k(r, 2) 127 | 0.0 128 | >>> precision_at_k(r, 3) 129 | 0.33333333333333331 130 | >>> precision_at_k(r, 4) 131 | Traceback (most recent call last): 132 | File "", line 1, in ? 133 | ValueError: Relevance score length < k 134 | Args: 135 | r: Relevance scores (list or numpy) in rank order 136 | (first element is the first item) 137 | Returns: 138 | Precision @ k 139 | Raises: 140 | ValueError: len(r) must be >= k 141 | """ 142 | assert k >= 1 143 | r = np.asarray(r)[:k] != 0 144 | if r.size != k: 145 | raise ValueError("Relevance score length < k") 146 | return np.mean(r) 147 | 148 | 149 | def average_precision(r): 150 | """Score is average precision (area under PR curve) 151 | Relevance is binary (nonzero is relevant). 152 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1] 153 | >>> delta_r = 1. / sum(r) 154 | >>> sum([sum(r[:x + 1]) / (x + 1.) 
* delta_r for x, y in enumerate(r) if y]) 155 | 0.7833333333333333 156 | >>> average_precision(r) 157 | 0.78333333333333333 158 | Args: 159 | r: Relevance scores (list or numpy) in rank order 160 | (first element is the first item) 161 | Returns: 162 | Average precision 163 | """ 164 | r = np.asarray(r) != 0 165 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 166 | if not out: 167 | return 0.0 168 | return np.mean(out) 169 | 170 | 171 | def mean_average_precision(rs): 172 | """Score is mean average precision 173 | Relevance is binary (nonzero is relevant). 174 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]] 175 | >>> mean_average_precision(rs) 176 | 0.78333333333333333 177 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]] 178 | >>> mean_average_precision(rs) 179 | 0.39166666666666666 180 | Args: 181 | rs: Iterator of relevance scores (list or numpy) in rank order 182 | (first element is the first item) 183 | Returns: 184 | Mean average precision 185 | """ 186 | return np.mean([average_precision(r) for r in rs]) 187 | 188 | 189 | def dcg_at_k(r, k, method=1): 190 | """Score is discounted cumulative gain (dcg) 191 | Relevance can be any non-negative real value; binary 192 | relevance, as in the previous methods, also works. 193 | Example from 194 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 195 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 196 | >>> dcg_at_k(r, 1, method=0) 197 | 3.0 198 | >>> dcg_at_k(r, 1, method=1) 199 | 3.0 200 | >>> dcg_at_k(r, 2, method=0) 201 | 5.0 202 | >>> dcg_at_k(r, 2, method=1) 203 | 4.2618595071429155 204 | >>> dcg_at_k(r, 10, method=0) 205 | 9.6051177391888114 206 | >>> dcg_at_k(r, 11, method=0) 207 | 9.6051177391888114 208 | Args: 209 | r: Relevance scores (list or numpy) in rank order 210 | (first element is the first item) 211 | k: Number of results to consider 212 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 213 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
214 | Returns: 215 | Discounted cumulative gain 216 | """ 217 | r = np.asarray(r, dtype=float)[:k] # np.asfarray was removed in NumPy 2.0 218 | if r.size: 219 | if method == 0: 220 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 221 | elif method == 1: 222 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 223 | else: 224 | raise ValueError("method must be 0 or 1.") 225 | return 0.0 226 | 227 | 228 | def ndcg_at_k(r, k, method=1): 229 | """Score is normalized discounted cumulative gain (ndcg) 230 | Relevance can be any non-negative real value; binary 231 | relevance, as in the previous methods, also works. 232 | Example from 233 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 234 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 235 | >>> ndcg_at_k(r, 1) 236 | 1.0 237 | >>> r = [2, 1, 2, 0] 238 | >>> ndcg_at_k(r, 4, method=0) 239 | 0.9203032077642922 240 | >>> ndcg_at_k(r, 4, method=1) 241 | 0.96519546960144276 242 | >>> ndcg_at_k([0], 1) 243 | 0.0 244 | >>> ndcg_at_k([1], 2) 245 | 1.0 246 | Args: 247 | r: Relevance scores (list or numpy) in rank order 248 | (first element is the first item) 249 | k: Number of results to consider 250 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 251 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 252 | Returns: 253 | Normalized discounted cumulative gain 254 | """ 255 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) 256 | if not dcg_max: 257 | return 0.0 258 | return dcg_at_k(r, k, method) / dcg_max 259 | 260 | 261 | def evaluate_once(topk_preds, groundtruth): 262 | """Evaluate the performance for a single user. 263 | Args: 264 | topk_preds: list of . length of the list is topK. 265 | groundtruth: list of . 266 | Returns: 267 | dict of metrics.
268 | """ 269 | gt_set = set(groundtruth) 270 | topk = len(topk_preds) 271 | rel = [] 272 | for iid in topk_preds: 273 | if iid in gt_set: 274 | rel.append(1) 275 | else: 276 | rel.append(0) 277 | return { 278 | "precision@k": precision_at_k(rel, topk), 279 | "recall@k": recall_at_k(rel, topk, len(gt_set)), 280 | "ndcg@k": ndcg_at_k(rel, topk, 1), 281 | "hit@k": hit_at_k(rel, topk), 282 | "ap": average_precision(rel), 283 | "rel": rel, 284 | } 285 | 286 | 287 | def evaluate_all(user_item_scores, groudtruth, topk=10): 288 | """Evaluate recommendation performance averaged over all users. 289 | Args: 290 | user_item_scores: dict with key = , value = . 291 | A larger score must mean a better recommendation. 292 | groudtruth: dict with key = , value = list of . 293 | topk: int, number of top-scored items to evaluate per user 294 | Returns: a (msg, res) tuple: the formatted metrics table and a dict of the averaged metrics. 295 | """ 296 | avg_prec, avg_recall, avg_ndcg, avg_hit = 0.0, 0.0, 0.0, 0.0 297 | rs = [] 298 | cnt = 0 299 | for uid in user_item_scores: 300 | # [Important] Shuffle first so that score ties are broken randomly. 301 | ui_scores = list(user_item_scores[uid].items()) 302 | np.random.shuffle(ui_scores) # break ties 303 | # topk_preds = heapq.nlargest(topk, user_item_scores[uid], key=user_item_scores[uid].get) # list of k 304 | topk_preds = heapq.nlargest(topk, ui_scores, key=lambda x: x[1]) # list of k tuples 305 | topk_preds = [x[0] for x in topk_preds] # list of k 306 | # print(topk_preds, groudtruth[uid]) 307 | result = evaluate_once(topk_preds, groudtruth[uid]) 308 | avg_prec += result["precision@k"] 309 | avg_recall += result["recall@k"] 310 | avg_ndcg += result["ndcg@k"] 311 | avg_hit += result["hit@k"] 312 | rs.append(result["rel"]) 313 | cnt += 1 314 | 315 | # [CAVEAT] The following (commented-out) code calculates metrics for each gt item.
316 | # for iid in groudtruth[uid]: 317 | # result = evaluate_once(topk_preds, [iid]) 318 | # avg_prec += result["precision@k"] 319 | # avg_recall += result["recall@k"] 320 | # avg_ndcg += result["ndcg@k"] 321 | # avg_hit += result["hit@k"] 322 | # rs.append(result["rel"]) 323 | # cnt += 1 324 | 325 | avg_prec = avg_prec / cnt 326 | avg_recall = avg_recall / cnt 327 | avg_ndcg = avg_ndcg / cnt 328 | avg_hit = avg_hit / cnt 329 | map_ = mean_average_precision(rs) 330 | mrr = mean_reciprocal_rank(rs) 331 | msg = "\nNDCG@{}\tRec@{}\tHits@{}\tPrec@{}\tMAP@{}\tMRR@{}".format(topk, topk, topk, topk, topk, topk) 332 | msg += "\n{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}".format(avg_ndcg, avg_recall, avg_hit, avg_prec, map_, mrr) 333 | # msg = "NDCG@{}\tRec@{}\tMAP@{}".format(topk, topk, topk) 334 | # msg += "\n{:.4f}\t{:.4f}\t{:.4f}".format(avg_ndcg, avg_recall, map) 335 | print(msg) 336 | res = { 337 | 'ndcg': avg_ndcg, 338 | 'map': map_, 339 | 'recall': avg_recall, 340 | 'precision': avg_prec, 341 | 'mrr': mrr, 342 | 'hit': avg_hit, 343 | } 344 | return msg, res 345 | 346 | 347 | def main(): 348 | ui_scores = { 349 | 1: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7}, 350 | # 2: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7}, 351 | # 3: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7}, 352 | # 4: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7}, 353 | # 5: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7}, 354 | } 355 | gt = { 356 | 1: [11, 15], 357 | # 2: [12, 13], 358 | # 3: [11, 14], 359 | # 4: [12, 15], 360 | # 5: [11], 361 | } 362 | evaluate_all(ui_scores, gt, 5) 363 | 364 | # pred = {} 365 | # for uid in ui_scores: 366 | # pred[uid] = heapq.nlargest(3, ui_scores[uid], key=ui_scores[uid].get) 367 | # evaluate_old(pred, gt, 3) 368 | 369 | 370 | if __name__ == "__main__": 371 | main() 372 | -------------------------------------------------------------------------------- /src/evaluation/evaluate/rouge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from 
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/rouge.py 3 | 4 | ROUGE metric implementation. 5 | 6 | Copied from tf_seq2seq/seq2seq/metrics/rouge.py. 7 | This is a modified and slightly extended version of 8 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py. 9 | """ 10 | 11 | from __future__ import absolute_import 12 | from __future__ import division 13 | from __future__ import print_function 14 | from __future__ import unicode_literals 15 | 16 | import itertools 17 | import numpy as np 18 | 19 | #pylint: disable=C0103 20 | 21 | 22 | def _get_ngrams(n, text): 23 | """Calculates n-grams. 24 | 25 | Args: 26 | n: the n-gram order 27 | text: An array of tokens 28 | 29 | Returns: 30 | A set of n-grams 31 | """ 32 | ngram_set = set() 33 | text_length = len(text) 34 | max_index_ngram_start = text_length - n 35 | for i in range(max_index_ngram_start + 1): 36 | ngram_set.add(tuple(text[i:i + n])) 37 | return ngram_set 38 | 39 | 40 | def _split_into_words(sentences): 41 | """Splits multiple sentences into words and flattens the result""" 42 | return list(itertools.chain(*[_.split(" ") for _ in sentences])) 43 | 44 | 45 | def _get_word_ngrams(n, sentences): 46 | """Calculates word n-grams for multiple sentences. 47 | """ 48 | assert len(sentences) > 0 49 | assert n > 0 50 | 51 | words = _split_into_words(sentences) 52 | return _get_ngrams(n, words) 53 | 54 | 55 | def _len_lcs(x, y): 56 | """ 57 | Returns the length of the Longest Common Subsequence between sequences x 58 | and y. 59 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 60 | 61 | Args: 62 | x: sequence of words 63 | y: sequence of words 64 | 65 | Returns: 66 | integer: Length of LCS between x and y 67 | """ 68 | table = _lcs(x, y) 69 | n, m = len(x), len(y) 70 | return table[n, m] 71 | 72 | 73 | def _lcs(x, y): 74 | """ 75 | Computes the length of the longest common subsequence (lcs) between two 76 | strings.
The implementation below uses a dynamic programming algorithm and runs 77 | in O(nm) time where n = len(x) and m = len(y). 78 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 79 | 80 | Args: 81 | x: collection of words 82 | y: collection of words 83 | 84 | Returns: 85 | A dict mapping each (i, j) to the LCS length of x[:i] and y[:j] 86 | """ 87 | n, m = len(x), len(y) 88 | table = dict() 89 | for i in range(n + 1): 90 | for j in range(m + 1): 91 | if i == 0 or j == 0: 92 | table[i, j] = 0 93 | elif x[i - 1] == y[j - 1]: 94 | table[i, j] = table[i - 1, j - 1] + 1 95 | else: 96 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 97 | return table 98 | 99 | 100 | def _recon_lcs(x, y): 101 | """ 102 | Returns the longest common subsequence between x and y. 103 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 104 | 105 | Args: 106 | x: sequence of words 107 | y: sequence of words 108 | 109 | Returns: 110 | sequence: LCS of x and y 111 | """ 112 | i, j = len(x), len(y) 113 | table = _lcs(x, y) 114 | 115 | def _recon(i, j): 116 | """Recursively reconstructs the LCS from the DP table.""" 117 | if i == 0 or j == 0: 118 | return [] 119 | elif x[i - 1] == y[j - 1]: 120 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 121 | elif table[i - 1, j] > table[i, j - 1]: 122 | return _recon(i - 1, j) 123 | else: 124 | return _recon(i, j - 1) 125 | 126 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j))) 127 | return recon_tuple 128 | 129 | 130 | def rouge_n(evaluated_sentences, reference_sentences, n=2): 131 | """ 132 | Computes ROUGE-N of two text collections of sentences. 133 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/ 134 | papers/rouge-working-note-v1.3.1.pdf 135 | 136 | Args: 137 | evaluated_sentences: The sentences that have been picked by the summarizer 138 | reference_sentences: The sentences from the reference set 139 | n: Size of ngram. Defaults to 2.
140 | 141 | Returns: 142 | A tuple (f1, precision, recall) for ROUGE-N 143 | 144 | Raises: 145 | ValueError: raises exception if a param has len <= 0 146 | """ 147 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 148 | raise ValueError("Collections must contain at least 1 sentence.") 149 | 150 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences) 151 | reference_ngrams = _get_word_ngrams(n, reference_sentences) 152 | reference_count = len(reference_ngrams) 153 | evaluated_count = len(evaluated_ngrams) 154 | 155 | # Gets the overlapping ngrams between evaluated and reference 156 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams) 157 | overlapping_count = len(overlapping_ngrams) 158 | 159 | # Handle edge case. This isn't mathematically correct, but it's good enough 160 | if evaluated_count == 0: 161 | precision = 0.0 162 | else: 163 | precision = overlapping_count / evaluated_count 164 | 165 | if reference_count == 0: 166 | recall = 0.0 167 | else: 168 | recall = overlapping_count / reference_count 169 | 170 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8)) 171 | 172 | # return overlapping_count / reference_count 173 | return f1_score, precision, recall 174 | 175 | 176 | def _f_p_r_lcs(llcs, m, n): 177 | """ 178 | Computes the LCS-based F-measure, precision and recall 179 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 180 | rouge-working-note-v1.3.1.pdf 181 | 182 | Args: 183 | llcs: Length of LCS 184 | m: number of words in reference summary 185 | n: number of words in candidate summary 186 | 187 | Returns: 188 | A tuple of floats (f_lcs, p_lcs, r_lcs): the LCS-based F-measure, precision and recall 189 | """ 190 | r_lcs = llcs / m 191 | p_lcs = llcs / n 192 | beta = p_lcs / (r_lcs + 1e-12) 193 | num = (1 + (beta**2)) * r_lcs * p_lcs 194 | denom = r_lcs + ((beta**2) * p_lcs) 195 | f_lcs = num / (denom + 1e-12) 196 | return f_lcs, p_lcs, r_lcs 197 | 198 | 199 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences): 200 | """ 201 | Computes ROUGE-L (sentence level) of two text collections of sentences. 202 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 203 | rouge-working-note-v1.3.1.pdf 204 | 205 | Calculated according to: 206 | R_lcs = LCS(X,Y)/m 207 | P_lcs = LCS(X,Y)/n 208 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 209 | 210 | where: 211 | X = reference summary 212 | Y = candidate summary 213 | m = length of reference summary 214 | n = length of candidate summary 215 | 216 | Args: 217 | evaluated_sentences: The sentences that have been picked by the summarizer 218 | reference_sentences: The sentences from the reference set 219 | 220 | Returns: 221 | A tuple of floats (F_lcs, P_lcs, R_lcs) 222 | 223 | Raises: 224 | ValueError: raises exception if a param has len <= 0 225 | """ 226 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 227 | raise ValueError("Collections must contain at least 1 sentence.") 228 | reference_words = _split_into_words(reference_sentences) 229 | evaluated_words = _split_into_words(evaluated_sentences) 230 | m = len(reference_words) 231 | n = len(evaluated_words) 232 | lcs = _len_lcs(evaluated_words, reference_words) 233 | return _f_p_r_lcs(lcs, m, n) 234 | 235 | 236 | def _union_lcs(evaluated_sentences, reference_sentence): 237 | """ 238 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common 239 | subsequence between reference sentence r_i and candidate summary C.
For example 240 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and 241 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is 242 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The 243 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and 244 | LCS_u(r_i, C) = 4/5. 245 | 246 | Args: 247 | evaluated_sentences: The sentences that have been picked by the summarizer 248 | reference_sentence: One of the sentences in the reference summaries 249 | 250 | Returns: 251 | float: LCS_u(r_i, C) 252 | 253 | ValueError: 254 | Raises exception if a param has len <= 0 255 | """ 256 | if len(evaluated_sentences) <= 0: 257 | raise ValueError("Collections must contain at least 1 sentence.") 258 | 259 | lcs_union = set() 260 | reference_words = _split_into_words([reference_sentence]) 261 | combined_lcs_length = 0 262 | for eval_s in evaluated_sentences: 263 | evaluated_words = _split_into_words([eval_s]) 264 | lcs = set(_recon_lcs(reference_words, evaluated_words)) 265 | combined_lcs_length += len(lcs) 266 | lcs_union = lcs_union.union(lcs) 267 | 268 | union_lcs_count = len(lcs_union) 269 | union_lcs_value = union_lcs_count / combined_lcs_length 270 | return union_lcs_value 271 | 272 | 273 | def rouge_l_summary_level(evaluated_sentences, reference_sentences): 274 | """ 275 | Computes ROUGE-L (summary level) of two text collections of sentences. 
276 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/ 277 | rouge-working-note-v1.3.1.pdf 278 | 279 | Calculated according to: 280 | R_lcs = SUM(1, u)[LCS_u(r_i,C)]/m 281 | P_lcs = SUM(1, u)[LCS_u(r_i,C)]/n 282 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 283 | 284 | where: 285 | SUM(1, u) = SUM over i from 1 through u 286 | u = number of sentences in reference summary 287 | C = candidate summary made up of v sentences 288 | m = number of words in reference summary 289 | n = number of words in candidate summary 290 | 291 | Args: 292 | evaluated_sentences: The sentences that have been picked by the summarizer 293 | reference_sentences: The sentences of the reference summary 294 | 295 | Returns: 296 | A tuple of floats (F_lcs, P_lcs, R_lcs) 297 | 298 | Raises: 299 | ValueError: raises exception if a param has len <= 0 300 | """ 301 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 302 | raise ValueError("Collections must contain at least 1 sentence.") 303 | 304 | # total number of words in reference sentences 305 | m = len(_split_into_words(reference_sentences)) 306 | 307 | # total number of words in evaluated sentences 308 | n = len(_split_into_words(evaluated_sentences)) 309 | 310 | union_lcs_sum_across_all_references = 0 311 | for ref_s in reference_sentences: 312 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences, 313 | ref_s) 314 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n) 315 | 316 | 317 | def rouge(hypotheses, references): 318 | """Calculates average ROUGE scores for a list of hypotheses and 319 | references""" 320 | 321 | # Filter out hyps that are of 0 length 322 | # hyps_and_refs = zip(hypotheses, references) 323 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 324 | # hypotheses, references = zip(*hyps_and_refs) 325 | 326 | # Calculate ROUGE-1 F1, precision, recall scores 327 | rouge_1 = [ 328 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references) 329 | ]
330 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1)) 331 | 332 | # Calculate ROUGE-2 F1, precision, recall scores 333 | rouge_2 = [ 334 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references) 335 | ] 336 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2)) 337 | 338 | # Calculate ROUGE-L F1, precision, recall scores 339 | rouge_l = [ 340 | rouge_l_sentence_level([hyp], [ref]) 341 | for hyp, ref in zip(hypotheses, references) 342 | ] 343 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l)) 344 | 345 | return { 346 | "rouge_1/f_score": rouge_1_f, 347 | "rouge_1/r_score": rouge_1_r, 348 | "rouge_1/p_score": rouge_1_p, 349 | "rouge_2/f_score": rouge_2_f, 350 | "rouge_2/r_score": rouge_2_r, 351 | "rouge_2/p_score": rouge_2_p, 352 | "rouge_l/f_score": rouge_l_f, 353 | "rouge_l/r_score": rouge_l_r, 354 | "rouge_l/p_score": rouge_l_p, 355 | } 356 | -------------------------------------------------------------------------------- /src/evaluation/evaluate/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import heapq 5 | import random 6 | import pickle 7 | import datetime 8 | from .rouge import rouge 9 | from .bleu import compute_bleu 10 | 11 | 12 | def rouge_score(references, generated): 13 | """both are a list of strings""" 14 | score = rouge(generated, references) 15 | rouge_s = {k: (v * 100) for (k, v) in score.items()} 16 | ''' 17 | "rouge_1/f_score": rouge_1_f, 18 | "rouge_1/r_score": rouge_1_r, 19 | "rouge_1/p_score": rouge_1_p, 20 | "rouge_2/f_score": rouge_2_f, 21 | "rouge_2/r_score": rouge_2_r, 22 | "rouge_2/p_score": rouge_2_p, 23 | "rouge_l/f_score": rouge_l_f, 24 | "rouge_l/r_score": rouge_l_r, 25 | "rouge_l/p_score": rouge_l_p, 26 | ''' 27 | return rouge_s 28 | 29 | 30 | def bleu_score(references, generated, n_gram=4, smooth=False): 31 | """a list of lists of tokens""" 32 | formatted_ref = [[ref] for ref in references] 
33 | bleu_s, _, _, _, _, _ = compute_bleu(formatted_ref, generated, n_gram, smooth) 34 | return bleu_s * 100 35 | 36 | 37 | def two_seq_same(sa, sb): 38 | if len(sa) != len(sb): 39 | return False 40 | for (wa, wb) in zip(sa, sb): 41 | if wa != wb: 42 | return False 43 | return True 44 | 45 | 46 | def unique_sentence_percent(sequence_batch): 47 | unique_seq = [] 48 | for seq in sequence_batch: 49 | count = 0 50 | for uni_seq in unique_seq: 51 | if two_seq_same(seq, uni_seq): 52 | count += 1 53 | break 54 | if count == 0: 55 | unique_seq.append(seq) 56 | 57 | return len(unique_seq) / len(sequence_batch), len(unique_seq) 58 | 59 | 60 | def feature_detect(seq_batch, feature_set): 61 | feature_batch = [] 62 | for ids in seq_batch: 63 | feature_list = [] 64 | for i in ids: 65 | if i in feature_set: 66 | feature_list.append(i) 67 | feature_batch.append(set(feature_list)) 68 | 69 | return feature_batch 70 | 71 | 72 | def feature_matching_ratio(feature_batch, test_feature): 73 | count = 0 74 | for (fea_set, fea) in zip(feature_batch, test_feature): 75 | if fea in fea_set: 76 | count += 1 77 | 78 | return count / len(feature_batch) 79 | 80 | 81 | def feature_coverage_ratio(feature_batch, feature_set): 82 | features = set() 83 | for fb in feature_batch: 84 | features = features | fb 85 | 86 | return len(features) / len(feature_set) 87 | 88 | 89 | def feature_diversity(feature_batch): 90 | list_len = len(feature_batch) 91 | 92 | total_count = 0 93 | for i, x in enumerate(feature_batch): 94 | for j in range(i + 1, list_len): 95 | y = feature_batch[j] 96 | total_count += len(x & y) 97 | 98 | denominator = list_len * (list_len - 1) / 2 99 | return total_count / denominator 100 | 101 | 102 | def mean_absolute_error(predicted, max_r, min_r, mae=True): 103 | total = 0 104 | for (r, p) in predicted: 105 | if p > max_r: 106 | p = max_r 107 | if p < min_r: 108 | p = min_r 109 | 110 | sub = p - r 111 | if mae: 112 | total += abs(sub) 113 | else: 114 | total += sub ** 2 115 | 116 | 
return total / len(predicted) 117 | 118 | 119 | def root_mean_square_error(predicted, max_r, min_r): 120 | mse = mean_absolute_error(predicted, max_r, min_r, False) 121 | return math.sqrt(mse) 122 | 123 | 124 | class WordDictionary: 125 | def __init__(self): 126 | self.idx2word = ['', '', '', ''] 127 | self.__predefine_num = len(self.idx2word) 128 | self.word2idx = {w: i for i, w in enumerate(self.idx2word)} 129 | self.__word2count = {} 130 | 131 | def add_sentence(self, sentence): 132 | for w in sentence.split(): 133 | self.add_word(w) 134 | 135 | def add_word(self, w): 136 | if w not in self.word2idx: 137 | self.word2idx[w] = len(self.idx2word) 138 | self.idx2word.append(w) 139 | self.__word2count[w] = 1 140 | else: 141 | self.__word2count[w] += 1 142 | 143 | def __len__(self): 144 | return len(self.idx2word) 145 | 146 | def keep_most_frequent(self, max_vocab_size=20000): 147 | if len(self.__word2count) > max_vocab_size: 148 | frequent_words = heapq.nlargest(max_vocab_size, self.__word2count, key=self.__word2count.get) 149 | self.idx2word = self.idx2word[:self.__predefine_num] + frequent_words 150 | self.word2idx = {w: i for i, w in enumerate(self.idx2word)} 151 | 152 | 153 | class EntityDictionary: 154 | def __init__(self): 155 | self.idx2entity = [] 156 | self.entity2idx = {} 157 | 158 | def add_entity(self, e): 159 | if e not in self.entity2idx: 160 | self.entity2idx[e] = len(self.idx2entity) 161 | self.idx2entity.append(e) 162 | 163 | def __len__(self): 164 | return len(self.idx2entity) 165 | 166 | 167 | class DataLoader: 168 | def __init__(self, data_path, index_dir, vocab_size): 169 | self.word_dict = WordDictionary() 170 | self.user_dict = EntityDictionary() 171 | self.item_dict = EntityDictionary() 172 | self.max_rating = float('-inf') 173 | self.min_rating = float('inf') 174 | self.initialize(data_path) 175 | self.word_dict.keep_most_frequent(vocab_size) 176 | self.__unk = self.word_dict.word2idx[''] 177 | self.feature_set = set() 178 | self.train, 
self.valid, self.test = self.load_data(data_path, index_dir) 179 | 180 | def initialize(self, data_path): 181 | assert os.path.exists(data_path) 182 | reviews = pickle.load(open(data_path, 'rb')) 183 | for review in reviews: 184 | self.user_dict.add_entity(review['user']) 185 | self.item_dict.add_entity(review['item']) 186 | (fea, adj, tem, sco) = review['template'] 187 | self.word_dict.add_sentence(tem) 188 | self.word_dict.add_word(fea) 189 | rating = review['rating'] 190 | if self.max_rating < rating: 191 | self.max_rating = rating 192 | if self.min_rating > rating: 193 | self.min_rating = rating 194 | 195 | def load_data(self, data_path, index_dir): 196 | data = [] 197 | reviews = pickle.load(open(data_path, 'rb')) 198 | for review in reviews: 199 | (fea, adj, tem, sco) = review['template'] 200 | data.append({'user': self.user_dict.entity2idx[review['user']], 201 | 'item': self.item_dict.entity2idx[review['item']], 202 | 'rating': review['rating'], 203 | 'text': self.seq2ids(tem), 204 | 'feature': self.word_dict.word2idx.get(fea, self.__unk)}) 205 | if fea in self.word_dict.word2idx: 206 | self.feature_set.add(fea) 207 | else: 208 | self.feature_set.add('') 209 | 210 | train_index, valid_index, test_index = self.load_index(index_dir) 211 | train, valid, test = [], [], [] 212 | for idx in train_index: 213 | train.append(data[idx]) 214 | for idx in valid_index: 215 | valid.append(data[idx]) 216 | for idx in test_index: 217 | test.append(data[idx]) 218 | return train, valid, test 219 | 220 | def seq2ids(self, seq): 221 | return [self.word_dict.word2idx.get(w, self.__unk) for w in seq.split()] 222 | 223 | def load_index(self, index_dir): 224 | assert os.path.exists(index_dir) 225 | with open(os.path.join(index_dir, 'train.index'), 'r') as f: 226 | train_index = [int(x) for x in f.readline().split(' ')] 227 | with open(os.path.join(index_dir, 'validation.index'), 'r') as f: 228 | valid_index = [int(x) for x in f.readline().split(' ')] 229 | with 
open(os.path.join(index_dir, 'test.index'), 'r') as f: 230 | test_index = [int(x) for x in f.readline().split(' ')] 231 | return train_index, valid_index, test_index 232 | 233 | 234 | def sentence_format(sentence, max_len, pad, bos, eos): 235 | length = len(sentence) 236 | if length >= max_len: 237 | return [bos] + sentence[:max_len] + [eos] 238 | else: 239 | return [bos] + sentence + [eos] + [pad] * (max_len - length) 240 | 241 | 242 | class Batchify: 243 | def __init__(self, data, word2idx, seq_len=15, batch_size=128, shuffle=False): 244 | bos = word2idx[''] 245 | eos = word2idx[''] 246 | pad = word2idx[''] 247 | u, i, r, t, f = [], [], [], [], [] 248 | for x in data: 249 | u.append(x['user']) 250 | i.append(x['item']) 251 | r.append(x['rating']) 252 | t.append(sentence_format(x['text'], seq_len, pad, bos, eos)) 253 | f.append([x['feature']]) 254 | 255 | self.user = torch.tensor(u, dtype=torch.int64).contiguous() 256 | self.item = torch.tensor(i, dtype=torch.int64).contiguous() 257 | self.rating = torch.tensor(r, dtype=torch.float).contiguous() 258 | self.seq = torch.tensor(t, dtype=torch.int64).contiguous() 259 | self.feature = torch.tensor(f, dtype=torch.int64).contiguous() 260 | self.shuffle = shuffle 261 | self.batch_size = batch_size 262 | self.sample_num = len(data) 263 | self.index_list = list(range(self.sample_num)) 264 | self.total_step = int(math.ceil(self.sample_num / self.batch_size)) 265 | self.step = 0 266 | 267 | def next_batch(self): 268 | if self.step == self.total_step: 269 | self.step = 0 270 | if self.shuffle: 271 | random.shuffle(self.index_list) 272 | 273 | start = self.step * self.batch_size 274 | offset = min(start + self.batch_size, self.sample_num) 275 | self.step += 1 276 | index = self.index_list[start:offset] 277 | user = self.user[index] # (batch_size,) 278 | item = self.item[index] 279 | rating = self.rating[index] 280 | seq = self.seq[index] # (batch_size, seq_len) 281 | feature = self.feature[index] # (batch_size, 1) 282 | return 
user, item, rating, seq, feature 283 | 284 | 285 | def now_time(): 286 | return '[' + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + ']: ' 287 | 288 | 289 | def ids2tokens(ids, word2idx, idx2word): 290 | eos = word2idx[''] 291 | tokens = [] 292 | for i in ids: 293 | if i == eos: 294 | break 295 | tokens.append(idx2word[i]) 296 | return tokens 297 | -------------------------------------------------------------------------------- /src/evaluation/evaluate_chatglm_result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # -------------------------------------------- 4 | # @FileName: evaluate_chatglm_result.py 5 | # @Author: ljl 6 | # @Time: 2023/5/10 7 | # @Description: 8 | # -------------------------------------------- 9 | 10 | import os 11 | import re 12 | from transformers import AutoTokenizer, AutoModel 13 | import argparse 14 | import pandas as pd 15 | from tqdm import tqdm 16 | # template_multi (multiple-answer questions), in English: "Assume you are a medical expert and answer the following question. Note that it is a multiple-choice question. {question}:\n{options}\n Give two lines: the first contains only the option letters of the answer, the second a brief explanation, formatted as '答案:' (answer) and '解释:' (explanation)." 17 | template_multi = "假设你是一位医疗行业专家,请回答下列问题。注意,该问题是多选题\n" \ 18 | "{}:\n{}\n" \ 19 | "注意,请给出两行,第一行只需要返回答案的英文选项,第二行进行简要的解释。输出格式限制为“答案:”,“解释:”" 20 | # template_single (single-answer questions), in English: the same two-line '答案:'/'解释:' format, but the model must return only the single most appropriate option, even if several seem to fit. 21 | template_single = "返回限制:只返回两行。" \ 22 | "假设你是一位医疗行业专家,请回答下列问题,注意是单选题,只需要返回一个最合适的选项。\n" \ 23 | "{}:\n{}\n" \ 24 | "注意,结果只有两行,第一行只需要返回答案的英文选项(注意只需要返回一个最合适的答案),第二行进行简要的解释。输出格式限制为:“答案:”,“解释:”。\n" \ 25 | "注意,题目是单选题,若有多个合适的答案,只返回最准确的即可。" 26 | 27 | def prediction(args): 28 | # load the tokenizer and the model from their respective paths 29 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizerpath, trust_remote_code=True) 30 | model = AutoModel.from_pretrained(args.modelpath, trust_remote_code=True).half().cuda() 31 | model = model.eval() 32 | 33 | def predict(data): 34 | results = [] 35 | for content in tqdm(data): 36 | try: 37 | response, history = model.chat(tokenizer, content, history=[]) 38 | except Exception as e: 39 | response = "" 40 | results.append(response) 41 | return results 42 | 43 | # load csv 44 | csv = pd.read_csv(args.filepath) 45 | questions = csv['Question'].values.tolist() 46 | options =
csv['Options'].values.tolist() 47 | gt_answer = csv['Answer'].values.tolist() 48 | 49 | data = [] 50 | raw_results = [] 51 | for i in range(len(questions)): 52 | if len(gt_answer[i]) == 1: 53 | data.append(template_single.format(questions[i], options[i])) 54 | else: 55 | data.append(template_multi.format(questions[i], options[i])) 56 | 57 | raw_results.extend(predict(data)) 58 | predicted_answer = [] 59 | predicted_explanation = [] 60 | for single in raw_results: 61 | try: 62 | answer = re.findall(r"答案:(.*),", single)[0] 63 | exp = re.findall(r"解释:(.*)", single)[0] 64 | predicted_answer.append(answer) 65 | predicted_explanation.append(exp) 66 | except Exception as e: 67 | print(single, flush=True) 68 | predicted_answer.append("") 69 | predicted_explanation.append("") 70 | 71 | csv['raw_prediction'] = raw_results 72 | csv['predicted_answer'] = predicted_answer 73 | csv['predicted_explanation'] = predicted_explanation 74 | 75 | if os.path.dirname(args.savepath): 76 | os.makedirs(os.path.dirname(args.savepath), exist_ok=True) 77 | csv.to_csv(args.savepath, index=False) 78 | 79 | 80 | def evaluation(args): 81 | csv = pd.read_csv(args.savepath) 82 | 83 | gt_exp = csv['Explanation'].values.tolist() 84 | predict_exp = csv['predicted_explanation'].values.tolist() 85 | # process pd.na 86 | gt_exp = [item if not pd.isna(item) else "" for item in gt_exp] 87 | predict_exp = [item if not pd.isna(item) else "" for item in predict_exp] 88 | 89 | gt_answer = csv['Answer'].values.tolist() 90 | predict_answer = csv['predicted_answer'].values.tolist() 91 | gt_answer_with_value = [] 92 | predict_answer_with_value = [] 93 | 94 | total = 0.0 95 | correct = 0.0 96 | for i in range(len(gt_answer)): 97 | if not pd.isna(predict_answer[i]): 98 | total += 1 99 | gt_answer_with_value.append(gt_answer[i]) 100 | predict_answer_with_value.append(predict_answer[i]) 101 | if gt_answer[i] == predict_answer[i]: 102 | correct += 1 103 | 104 | gt_answer = gt_answer_with_value 105 | predict_answer = predict_answer_with_value 106 | 107
| print(total) 108 | print(correct / total) 109 | 110 | from sklearn.metrics import precision_recall_fscore_support 111 | precision, recall, fscore, _ = precision_recall_fscore_support(gt_answer, predict_answer, average='weighted') 112 | print('Precision: ', precision) 113 | print('Recall: ', recall) 114 | print('Fscore: ', fscore) 115 | 116 | from evaluate.utils import rouge_score, bleu_score, unique_sentence_percent, root_mean_square_error, \ 117 | mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity 118 | 119 | tokens_of_processed_predict_exps = [list(jieba.cut(item, cut_all=False)) for item in predict_exp] 120 | tokens_of_processed_gt_exps = [list(jieba.cut(item, cut_all=False)) for item in gt_exp] 121 | # tokens_of_processed_predict_exps = [list(item) for item in predict_exp] 122 | # tokens_of_processed_gt_exps = [list(item) for item in gt_exp] 123 | 124 | processed_gt_exps = [' '.join(list(item)) for item in gt_exp] 125 | processed_predict_exps = [' '.join(list(item)) for item in predict_exp] 126 | 127 | BLEU1 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=1, smooth=False) 128 | BLEU2 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=2, smooth=False) 129 | BLEU4 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=4, smooth=False) 130 | ROUGE = rouge_score(processed_gt_exps, processed_predict_exps) 131 | 132 | print('BLEU-1 {:7.4f}'.format(BLEU1)) 133 | print('BLEU-2 {:7.4f}'.format(BLEU2)) 134 | print('BLEU-4 {:7.4f}'.format(BLEU4)) 135 | for (k, v) in ROUGE.items(): 136 | print('{} {:7.4f}'.format(k, v)) 137 | 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("--filepath", type=str, default="../../data/test_with_annotations.csv") 142 | parser.add_argument("--savepath", type=str, default="../exp/test_with_chatglm.csv") 143 | parser.add_argument("--modelpath",
type=str, default="THUDM/chatglm-6b") 144 | parser.add_argument("--tokenizerpath", type=str, default="THUDM/chatglm-6b") 145 | args = parser.parse_args() 146 | prediction(args) 147 | evaluation(args) -------------------------------------------------------------------------------- /src/evaluation/evaluate_ft_result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # -------------------------------------------- 4 | # @FileName: calc_metrics.py 5 | # @Author: ljl 6 | # @Time: 2023/5/10 7 | # @Description: 8 | # -------------------------------------------- 9 | 10 | import pandas as pd 11 | import jieba 12 | 13 | filepath = 'test_predicted.csv' 14 | 15 | csv = pd.read_csv(filepath) 16 | 17 | gt_exp = csv['Explanation'].values.tolist() 18 | predict_exp = csv['explanation'].values.tolist() 19 | # process pd.na 20 | gt_exp = [item if not pd.isna(item) else "" for item in gt_exp] 21 | predict_exp = [item if not pd.isna(item) else "" for item in predict_exp] 22 | 23 | gt_answer = csv['Answer'].values.tolist() 24 | predict_answer = csv['answer_prediction'].values.tolist() 25 | gt_answer_with_value = [] 26 | predict_answer_with_value = [] 27 | 28 | total = 0.0 29 | correct = 0.0 30 | for i in range(len(gt_answer)): 31 | if not pd.isna(predict_answer[i]): 32 | total += 1 33 | gt_answer_with_value.append(gt_answer[i]) 34 | predict_answer_with_value.append(predict_answer[i]) 35 | if gt_answer[i] == predict_answer[i]: 36 | correct += 1 37 | 38 | 39 | gt_answer = gt_answer_with_value 40 | predict_answer = predict_answer_with_value 41 | 42 | print(total) 43 | print(correct / total) 44 | 45 | from sklearn.metrics import precision_recall_fscore_support 46 | precision, recall, fscore, _ = precision_recall_fscore_support(gt_answer, predict_answer, average='weighted') 47 | print('Precision: ', precision) 48 | print('Recall: ', recall) 49 | print('Fscore: ', fscore) 50 | 51 | from
src.evaluation.evaluate.utils import rouge_score, bleu_score, unique_sentence_percent, root_mean_square_error, mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity 52 | 53 | tokens_of_processed_predict_exps = [list(jieba.cut(item,cut_all=False)) for item in predict_exp] 54 | tokens_of_processed_gt_exps = [list(jieba.cut(item,cut_all=False)) for item in gt_exp] 55 | 56 | # tokens_of_processed_predict_exps = [list(item) for item in predict_exp] 57 | # tokens_of_processed_gt_exps = [list(item) for item in gt_exp] 58 | processed_gt_exps = [' '.join(list(item)) for item in gt_exp] 59 | processed_predict_exps = [' '.join(list(item)) for item in predict_exp] 60 | 61 | BLEU1 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=1, smooth=False) 62 | BLEU2 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=2, smooth=False) 63 | BLEU4 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=4, smooth=False) 64 | ROUGE = rouge_score(processed_gt_exps, processed_predict_exps) 65 | 66 | print('BLEU-1 {:7.4f}'.format(BLEU1)) 67 | print('BLEU-2 {:7.4f}'.format(BLEU2)) 68 | print('BLEU-4 {:7.4f}'.format(BLEU4)) 69 | for (k, v) in ROUGE.items(): 70 | print('{} {:7.4f}'.format(k, v)) 71 | -------------------------------------------------------------------------------- /src/evaluation/evaluate_gpt_result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # -------------------------------------------- 4 | # @FileName: translate.py 5 | # @Author: ljl 6 | # @Time: 2023/5/4 7 | # @Description: 8 | # -------------------------------------------- 9 | 10 | import openai 11 | import argparse 12 | import os 13 | import time 14 | import jieba 15 | from multiprocessing import Pool 16 | import pandas as pd 17 | from tqdm import tqdm 18 | os.environ["HTTP_PROXY"] = "socks5h://127.0.0.1:13659"
19 | os.environ["HTTPS_PROXY"] = "socks5h://127.0.0.1:13659" 20 | # os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890" 21 | # os.environ["HTTPS_PROXY"] = "https://127.0.0.1:7890" 22 | 23 | def call_api(data,question_nums, model): 24 | results = [] 25 | try: 26 | for i, content in tqdm(enumerate(data)): 27 | result = "" 28 | try: 29 | completion = openai.ChatCompletion.create( 30 | model=model, 31 | # model="gpt-4", 32 | # model="gpt-4-0314", 33 | messages=[{"role": "user", "content": content}] 34 | ) 35 | result = completion.choices[0].message.content 36 | except Exception as e: 37 | print(str(e), flush=True) 38 | results.append(result) 39 | except Exception as e: 40 | print(str(e), flush=True) 41 | results.extend(["[]" for _ in range(len(data)-len(results))]) 42 | return results,question_nums 43 | 44 | def prediction(args): 45 | openai.api_key = args.api_key 46 | 47 | csv = pd.read_csv(args.filepath) 48 | questions = csv['Question'].values.tolist() 49 | options = csv['Options'].values.tolist() 50 | 51 | template = "返回格式为一个python列表,包含每道题的答案英文选项和解释 \n" \ 52 | "假设你是一位医疗行业专家,请回答下列几个问题。\n" \ 53 | "题目信息为:{} \n" \ 54 | "注意,每个题目的回答以一个字符串保存,返回答案的英文选项,并进行简要的解释。字符串输出格式限制为“答案:**,解释:**”" 55 | data = [] 56 | question_nums = [] 57 | step = 5 58 | 59 | for i in range(0,len(questions),step): 60 | question_group = "" 61 | question_num = min(step, len(questions)-i) 62 | for j in range(question_num): 63 | question_group+="{}.题目信息为 {}:{}\n".format(str(j+1),questions[i+j], options[i+j].replace('\n',',')) 64 | 65 | data.append(template.format(question_group)) 66 | question_nums.append(question_num) 67 | 68 | # data = data[:2] 69 | # question_nums = question_nums[:2] 70 | 71 | # multiprocessing 72 | num_of_processes = 1 73 | pool = Pool(processes=num_of_processes) 74 | pool_results = [] 75 | each_size = len(data) // num_of_processes 76 | for i in range(num_of_processes): 77 | if i0: 34 | return "".join(predict_ops) 35 | else: 36 | return "无答案" 37 | def parse_explanations(row): 38 | # 
Extract the explanation part (the text after '解释:') from 'model_result' 39 | if not isinstance(row['model_result'],str): 40 | return "无答案" 41 | if '解释:' not in row['model_result']: 42 | original_result = row['model_result'] 43 | else: 44 | original_result = row['model_result'].split('解释:')[1].strip() 45 | return original_result 46 | def evaluate_reasoning(df): 47 | def add_spaces(l): 48 | return [' '.join(list(_)) for _ in l] 49 | source = '答案解析' 50 | target = 'parsed_explanation' 51 | df.dropna(subset=[source, target], inplace=True) 52 | tokens_predict = df[target].to_list() 53 | tokens_test = df[source].to_list() 54 | 55 | tokens_predict = add_spaces(tokens_predict) 56 | tokens_test = add_spaces(tokens_test) 57 | 58 | new_tokens_predict = [l.split() for l in tokens_predict] 59 | new_tokens_test = [ll.split() for ll in tokens_test] 60 | BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False) 61 | BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False) 62 | ROUGE = rouge_score(tokens_test, tokens_predict) 63 | 64 | print('BLEU-1 {:7.4f}'.format(BLEU1)) 65 | print('BLEU-4 {:7.4f}'.format(BLEU4)) 66 | for (k, v) in ROUGE.items(): 67 | if 'f_score' in k: 68 | print('{} {:7.4f}'.format(k, v)) 69 | 70 | def evaluate_prediction(df): 71 | correct = df[df['parsed_option']==df['答案']].shape[0] 72 | total = df.shape[0] 73 | num_no_answer = df[df['parsed_option']=='无答案'].shape[0] 74 | 75 | processed_gts = df['答案'].to_list() 76 | processed_results = df['parsed_option'].to_list() 77 | precision, recall, fscore, _ = precision_recall_fscore_support(processed_gts, processed_results, average='weighted') 78 | print('Precision: ', precision) 79 | print('Recall: ', recall) 80 | print('Fscore: ', fscore) 81 | print('Acc:{}'.format(correct/total*100)) 82 | print('The number of "No answers:"',num_no_answer) 83 | 84 | def main( 85 | csv_file_path: str = "../LoRA/output/medalpaca_4.csv", 86 | ): 87 | df = pd.read_csv(csv_file_path) 88 | 89 | df['parsed_option'] =
df.apply(parse_options,axis=1) 90 | df['parsed_explanation'] = df.apply(parse_explanations,axis=1) 91 | 92 | print('Evaluation of prediction:') 93 | evaluate_prediction(df) 94 | print('*'*20) 95 | print('Evaluation of reasoning:') 96 | evaluate_reasoning(df) 97 | 98 | if __name__ == "__main__": 99 | fire.Fire(main) -------------------------------------------------------------------------------- /src/preprocess/dataset_dist.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/src/preprocess/dataset_dist.pdf -------------------------------------------------------------------------------- /src/preprocess/generate_prompt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # -------------------------------------------- 4 | # @FileName: generate_prompt.py 5 | # @Author: ljl 6 | # @Time: 2023/5/15 7 | # @Description: 8 | # -------------------------------------------- 9 | 10 | import pandas as pd 11 | import json 12 | import copy 13 | import argparse 14 | from prompt_templates import all_task_templates 15 | 16 | 17 | def main(args): 18 | 19 | filepath = args.filepath 20 | 21 | csv = pd.read_csv(filepath) 22 | 23 | # prompt_templates = ["1","2","3","4","5","6"] 24 | prompt_templates = args.templates.split(",") 25 | 26 | prompts = [] 27 | 28 | for i,data in enumerate(csv.values): 29 | 30 | question = data[csv.columns.values.tolist().index("Question")] 31 | options = data[csv.columns.values.tolist().index("Options")] 32 | explanation = data[csv.columns.values.tolist().index("Explanation")] 33 | option_lists = options.split("\n") 34 | answer = data[csv.columns.values.tolist().index("Answer")] 35 | if pd.isna(answer): 36 | continue 37 | answer_content = "" 38 | for option in option_lists: 39 | if option.split(" ")[0] == answer: 40 | answer_content = option.split(" ")[-1] 41 | 42 | for 
prompt_idx in prompt_templates: 43 | prompt_template = copy.deepcopy(all_task_templates[prompt_idx]) 44 | try: 45 | if prompt_idx == "1": 46 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options) 47 | prompt_template["completion"] = prompt_template["completion"].format(answer) 48 | prompts.append(prompt_template) 49 | elif prompt_idx == "2": 50 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options) 51 | prompt_template["completion"] = prompt_template["completion"].format(answer+" "+ answer_content) 52 | prompts.append(prompt_template) 53 | elif prompt_idx == "3": 54 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options) 55 | prompt_template["completion"] = prompt_template["completion"].format(explanation) 56 | prompts.append(prompt_template) 57 | elif prompt_idx == "4": 58 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options) 59 | prompt_template["completion"] = prompt_template["completion"].format(answer+" "+ answer_content, explanation) 60 | prompts.append(prompt_template) 61 | elif prompt_idx == "5": 62 | prompt_template["prompt"] = prompt_template["prompt"].format(question) 63 | prompt_template["completion"] = prompt_template["completion"].format(answer_content) 64 | prompts.append(prompt_template) 65 | elif prompt_idx == "6": 66 | prompt_template["prompt"] = prompt_template["prompt"].format(question) 67 | prompt_template["completion"] = prompt_template["completion"].format(answer_content, explanation) 68 | prompts.append(prompt_template) 69 | except Exception as e: 70 | print(data) 71 | 72 | # save json 73 | savepath = filepath.replace(".csv", ".json") 74 | with open(savepath, 'w') as f: 75 | for prompt in prompts: 76 | json_file = { 77 | "prompt":prompt["prompt"], 78 | "completion":prompt["completion"], 79 | "id":prompt["id"] 80 | } 81 | json_str = json.dumps(json_file,ensure_ascii=False) 82 | f.write(json_str + '\n') 83 | f.close() 84 | 85 | # save 
csv 86 | savepath = filepath.replace(".csv", "_prompt.csv") 87 | # one row per generated prompt; a question yields several rows when several templates are selected 88 | prompt_df = pd.DataFrame(prompts, columns=["prompt", "completion", "id"]) 89 | 90 | prompt_df.to_csv(savepath, index=False) 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--filepath", type=str, required=True) 95 | parser.add_argument("--templates", type=str, default="1,2", help="To generate prompts using different templates") 96 | args = parser.parse_args() 97 | main(args) 98 | -------------------------------------------------------------------------------- /src/preprocess/prompt_templates.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # -------------------------------------------- 4 | # @FileName: prompt_templates.py 5 | # @Author: ljl 6 | # @Time: 2023/5/15 7 | # @Description: 8 | # -------------------------------------------- 9 | 10 | all_task_templates = {} 11 | 12 | template = {} 13 | template['prompt'] = "问题: {}, \n 选项: {}" 14 | template['completion'] = "答案: {}" 15 | template['id'] = "1" 16 | all_task_templates["1"] = template 17 | 18 | template = {} 19 | template['prompt'] = "问题: {}, \n 选项: {}" 20 | template['completion'] = "答案: {}" 21 | template['id'] = "2" 22 | all_task_templates["2"] = template 23 | 24 | template = {} 25 | template['prompt'] = "问题: {}, \n 选项: {}" 26 | template['completion'] = "解释: {}" 27 | template['id'] = "3" 28 | all_task_templates["3"] = template 29 | 30 | template = {} 31 | template['prompt'] = "问题: {}, \n 选项: {}" 32 | template['completion'] = "答案: {}.
\n 解释:{}" 33 | template['id'] = "4" 34 | all_task_templates["4"] = template 35 | 36 | template = {} 37 | template['prompt'] = "问题: {}" 38 | template['completion'] = "答案: {}" 39 | template['id'] = "5" 40 | all_task_templates["5"] = template 41 | 42 | template = {} 43 | template['prompt'] = "问题: {}" 44 | template['completion'] = "答案: {}. \n 解释: {}" 45 | template['id'] = "6" 46 | all_task_templates["6"] = template -------------------------------------------------------------------------------- /src/ptuning/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | ptuning_checkpoint: str = field( 15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} 16 | ) 17 | config_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 26 | ) 27 | use_fast_tokenizer: bool = field( 28 | default=True, 29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 30 | ) 31 | model_revision: str = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 34 | ) 35 | use_auth_token: bool = field( 36 | default=False, 37 | metadata={ 38 | "help": ( 39 
| "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 40 | "with private models)." 41 | ) 42 | }, 43 | ) 44 | resize_position_embeddings: Optional[bool] = field( 45 | default=None, 46 | metadata={ 47 | "help": ( 48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 49 | "the model's position embeddings." 50 | ) 51 | }, 52 | ) 53 | quantization_bit: Optional[int] = field( 54 | default=None 55 | ) 56 | pre_seq_len: Optional[int] = field( 57 | default=None 58 | ) 59 | prefix_projection: bool = field( 60 | default=False 61 | ) 62 | 63 | 64 | @dataclass 65 | class DataTrainingArguments: 66 | """ 67 | Arguments pertaining to what data we are going to input our model for training and eval. 68 | """ 69 | 70 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 71 | 72 | dataset_name: Optional[str] = field( 73 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 74 | ) 75 | dataset_config_name: Optional[str] = field( 76 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 77 | ) 78 | prompt_column: Optional[str] = field( 79 | default=None, 80 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 81 | ) 82 | response_column: Optional[str] = field( 83 | default=None, 84 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 85 | ) 86 | history_column: Optional[str] = field( 87 | default=None, 88 | metadata={"help": "The name of the column in the datasets containing the history of chat."}, 89 | ) 90 | train_file: Optional[str] = field( 91 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 92 | ) 93 | validation_file: Optional[str] = field( 94 | default=None, 95 | metadata={ 96 | "help": ( 97 | "An 
optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 98 | ) 99 | }, 100 | ) 101 | test_file: Optional[str] = field( 102 | default=None, 103 | metadata={ 104 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 105 | }, 106 | ) 107 | overwrite_cache: bool = field( 108 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 109 | ) 110 | preprocessing_num_workers: Optional[int] = field( 111 | default=None, 112 | metadata={"help": "The number of processes to use for the preprocessing."}, 113 | ) 114 | max_source_length: Optional[int] = field( 115 | default=1024, 116 | metadata={ 117 | "help": ( 118 | "The maximum total input sequence length after tokenization. Sequences longer " 119 | "than this will be truncated, sequences shorter will be padded." 120 | ) 121 | }, 122 | ) 123 | max_target_length: Optional[int] = field( 124 | default=128, 125 | metadata={ 126 | "help": ( 127 | "The maximum total sequence length for target text after tokenization. Sequences longer " 128 | "than this will be truncated, sequences shorter will be padded." 129 | ) 130 | }, 131 | ) 132 | val_max_target_length: Optional[int] = field( 133 | default=None, 134 | metadata={ 135 | "help": ( 136 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 137 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 138 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 139 | "during ``evaluate`` and ``predict``." 140 | ) 141 | }, 142 | ) 143 | pad_to_max_length: bool = field( 144 | default=False, 145 | metadata={ 146 | "help": ( 147 | "Whether to pad all samples to model maximum sentence length. " 148 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. 
More " 149 | "efficient on GPU but very bad for TPU." 150 | ) 151 | }, 152 | ) 153 | max_train_samples: Optional[int] = field( 154 | default=None, 155 | metadata={ 156 | "help": ( 157 | "For debugging purposes or quicker training, truncate the number of training examples to this " 158 | "value if set." 159 | ) 160 | }, 161 | ) 162 | max_eval_samples: Optional[int] = field( 163 | default=None, 164 | metadata={ 165 | "help": ( 166 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 167 | "value if set." 168 | ) 169 | }, 170 | ) 171 | max_predict_samples: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": ( 175 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 176 | "value if set." 177 | ) 178 | }, 179 | ) 180 | num_beams: Optional[int] = field( 181 | default=None, 182 | metadata={ 183 | "help": ( 184 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 185 | "which is used during ``evaluate`` and ``predict``." 186 | ) 187 | }, 188 | ) 189 | ignore_pad_token_for_loss: bool = field( 190 | default=True, 191 | metadata={ 192 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 193 | }, 194 | ) 195 | source_prefix: Optional[str] = field( 196 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 197 | ) 198 | 199 | forced_bos_token: Optional[str] = field( 200 | default=None, 201 | metadata={ 202 | "help": ( 203 | "The token to force as the first generated token after the decoder_start_token_id." 
204 | " Useful for multilingual models like mBART where the first generated token " 205 | "needs to be the target language token (usually it is the target language token)." 206 | ) 207 | }, 208 | ) 209 | 210 | 211 | 212 | def __post_init__(self): 213 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: 214 | raise ValueError("Need either a dataset name or a training/validation/test file.") 215 | else: 216 | if self.train_file is not None: 217 | extension = self.train_file.split(".")[-1] 218 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 219 | if self.validation_file is not None: 220 | extension = self.validation_file.split(".")[-1] 221 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 222 | if self.val_max_target_length is None: 223 | self.val_max_target_length = self.max_target_length 224 | 225 | -------------------------------------------------------------------------------- /src/ptuning/deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "fp16": { 5 | "enabled": "auto", 6 | "loss_scale": 0, 7 | "initial_scale_power": 16, 8 | "loss_scale_window": 1000, 9 | "hysteresis": 2, 10 | "min_loss_scale": 1 11 | }, 12 | "zero_optimization": { 13 | "stage": 2, 14 | "allgather_partitions": true, 15 | "allgather_bucket_size": 5e8, 16 | "overlap_comm": false, 17 | "reduce_scatter": true, 18 | "reduce_bucket_size": 5e8, 19 | "contiguous_gradients" : true 20 | } 21 | } -------------------------------------------------------------------------------- /src/ptuning/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | import json 25 | 26 | import numpy as np 27 | from datasets import load_dataset 28 | import jieba 29 | from rouge_chinese import Rouge 30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 31 | import torch 32 | 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModel, 37 | AutoTokenizer, 38 | AutoTokenizer, 39 | DataCollatorForSeq2Seq, 40 | HfArgumentParser, 41 | Seq2SeqTrainingArguments, 42 | set_seed, 43 | ) 44 | from trainer_seq2seq import Seq2SeqTrainer 45 | 46 | from arguments import ModelArguments, DataTrainingArguments 47 | 48 | logger = logging.getLogger(__name__) 49 | 50 | def main(): 51 | 52 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 53 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 54 | # If we pass only one argument to the script and it's the path to a json file, 55 | # let's parse it to get our arguments. 
56 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 57 | else: 58 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 59 | 60 | # Setup logging 61 | logging.basicConfig( 62 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 63 | datefmt="%m/%d/%Y %H:%M:%S", 64 | handlers=[logging.StreamHandler(sys.stdout)], 65 | ) 66 | 67 | if training_args.should_log: 68 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 69 | transformers.utils.logging.set_verbosity_info() 70 | 71 | log_level = training_args.get_process_log_level() 72 | logger.setLevel(log_level) 73 | # datasets.utils.logging.set_verbosity(log_level) 74 | transformers.utils.logging.set_verbosity(log_level) 75 | transformers.utils.logging.enable_default_handler() 76 | transformers.utils.logging.enable_explicit_format() 77 | 78 | # Log on each process the small summary: 79 | logger.warning( 80 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 81 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 82 | ) 83 | logger.info(f"Training/evaluation parameters {training_args}") 84 | 85 | # Set seed before initializing model.
86 | set_seed(training_args.seed) 87 | 88 | # Load dataset 89 | data_files = {} 90 | if data_args.train_file is not None: 91 | data_files["train"] = data_args.train_file 92 | extension = data_args.train_file.split(".")[-1] 93 | if data_args.validation_file is not None: 94 | data_files["validation"] = data_args.validation_file 95 | extension = data_args.validation_file.split(".")[-1] 96 | if data_args.test_file is not None: 97 | data_files["test"] = data_args.test_file 98 | extension = data_args.test_file.split(".")[-1] 99 | 100 | raw_datasets = load_dataset( 101 | extension, 102 | data_files=data_files, 103 | cache_dir=model_args.cache_dir, 104 | use_auth_token=True if model_args.use_auth_token else None, 105 | ) 106 | 107 | # Load pretrained model and tokenizer 108 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 109 | config.pre_seq_len = model_args.pre_seq_len 110 | config.prefix_projection = model_args.prefix_projection 111 | 112 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 113 | 114 | if model_args.ptuning_checkpoint is not None: 115 | # Evaluation 116 | # Loading extra state dict of prefix encoder 117 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 118 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) 119 | new_prefix_state_dict = {} 120 | for k, v in prefix_state_dict.items(): 121 | if k.startswith("transformer.prefix_encoder."): 122 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 123 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 124 | else: 125 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 126 | 127 | if model_args.quantization_bit is not None: 128 | print(f"Quantized to {model_args.quantization_bit} bit") 129 | model = 
model.quantize(model_args.quantization_bit) 130 | if model_args.pre_seq_len is not None: 131 | # P-tuning v2 132 | model = model.half() 133 | model.transformer.prefix_encoder.float() 134 | else: 135 | # Finetune 136 | model = model.float() 137 | 138 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 139 | 140 | # Preprocessing the datasets. 141 | # We need to tokenize inputs and targets. 142 | if training_args.do_train: 143 | column_names = raw_datasets["train"].column_names 144 | elif training_args.do_eval: 145 | column_names = raw_datasets["validation"].column_names 146 | elif training_args.do_predict: 147 | column_names = raw_datasets["test"].column_names 148 | else: 149 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 150 | return 151 | 152 | # Get the column names for input/target. 153 | prompt_column = data_args.prompt_column 154 | response_column = data_args.response_column 155 | history_column = data_args.history_column 156 | 157 | # Temporarily set max_target_length for training. 
158 | max_target_length = data_args.max_target_length 159 | 160 | def preprocess_function_eval(examples): 161 | inputs, targets = [], [] 162 | for i in range(len(examples[prompt_column])): 163 | if examples[prompt_column][i] and examples[response_column][i]: 164 | query = examples[prompt_column][i] 165 | if history_column is None or len(examples[history_column][i]) == 0: 166 | prompt = query 167 | else: 168 | prompt = "" 169 | history = examples[history_column][i] 170 | for turn_idx, (old_query, response) in enumerate(history): 171 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) 172 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 173 | inputs.append(prompt) 174 | targets.append(examples[response_column][i]) 175 | 176 | inputs = [prefix + inp for inp in inputs] 177 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) 178 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) 179 | 180 | if data_args.ignore_pad_token_for_loss: 181 | labels["input_ids"] = [ 182 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 183 | ] 184 | model_inputs["labels"] = labels["input_ids"] 185 | 186 | return model_inputs 187 | 188 | def preprocess_function_train(examples): 189 | max_seq_length = data_args.max_source_length + data_args.max_target_length 190 | 191 | model_inputs = { 192 | "input_ids": [], 193 | "labels": [], 194 | } 195 | for i in range(len(examples[prompt_column])): 196 | if examples[prompt_column][i] and examples[response_column][i]: 197 | query, answer = examples[prompt_column][i], examples[response_column][i] 198 | 199 | if history_column is None or len(examples[history_column][i]) == 0: 200 | prompt = query 201 | else: 202 | prompt = "" 203 | history = examples[history_column][i] 204 | for turn_idx, (old_query, response) in enumerate(history): 205 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) 206 | prompt
+= "[Round {}]\n问:{}\n答:".format(len(history), query) 207 | 208 | prompt = prefix + prompt 209 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 210 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False) 211 | 212 | if len(a_ids) > data_args.max_source_length - 1: 213 | a_ids = a_ids[: data_args.max_source_length - 1] 214 | 215 | if len(b_ids) > data_args.max_target_length - 2: 216 | b_ids = b_ids[: data_args.max_target_length - 2] 217 | 218 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) 219 | 220 | # import pdb;pdb.set_trace() 221 | context_length = input_ids.index(tokenizer.bos_token_id) 222 | mask_position = context_length - 1 223 | labels = [-100] * context_length + input_ids[mask_position+1:] 224 | 225 | pad_len = max_seq_length - len(input_ids) 226 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len 227 | labels = labels + [tokenizer.pad_token_id] * pad_len 228 | if data_args.ignore_pad_token_for_loss: 229 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] 230 | 231 | model_inputs["input_ids"].append(input_ids) 232 | model_inputs["labels"].append(labels) 233 | 234 | return model_inputs 235 | 236 | def print_dataset_example(example): 237 | print("input_ids",example["input_ids"]) 238 | print("inputs", tokenizer.decode(example["input_ids"])) 239 | print("label_ids", example["labels"]) 240 | print("labels", tokenizer.decode(example["labels"])) 241 | 242 | if training_args.do_train: 243 | if "train" not in raw_datasets: 244 | raise ValueError("--do_train requires a train dataset") 245 | train_dataset = raw_datasets["train"] 246 | if data_args.max_train_samples is not None: 247 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 248 | train_dataset = train_dataset.select(range(max_train_samples)) 249 | with training_args.main_process_first(desc="train dataset map pre-processing"): 250 | train_dataset = train_dataset.map( 251 | preprocess_function_train, 252 | 
batched=True, 253 | num_proc=data_args.preprocessing_num_workers, 254 | remove_columns=column_names, 255 | load_from_cache_file=not data_args.overwrite_cache, 256 | desc="Running tokenizer on train dataset", 257 | ) 258 | print_dataset_example(train_dataset[0]) 259 | 260 | if training_args.do_eval: 261 | max_target_length = data_args.val_max_target_length 262 | if "validation" not in raw_datasets: 263 | raise ValueError("--do_eval requires a validation dataset") 264 | eval_dataset = raw_datasets["validation"] 265 | if data_args.max_eval_samples is not None: 266 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 267 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 268 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 269 | eval_dataset = eval_dataset.map( 270 | preprocess_function_eval, 271 | batched=True, 272 | num_proc=data_args.preprocessing_num_workers, 273 | remove_columns=column_names, 274 | load_from_cache_file=not data_args.overwrite_cache, 275 | desc="Running tokenizer on validation dataset", 276 | ) 277 | print_dataset_example(eval_dataset[0]) 278 | 279 | if training_args.do_predict: 280 | max_target_length = data_args.val_max_target_length 281 | if "test" not in raw_datasets: 282 | raise ValueError("--do_predict requires a test dataset") 283 | predict_dataset = raw_datasets["test"] 284 | if data_args.max_predict_samples is not None: 285 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 286 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 287 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 288 | predict_dataset = predict_dataset.map( 289 | preprocess_function_eval, 290 | batched=True, 291 | num_proc=data_args.preprocessing_num_workers, 292 | remove_columns=column_names, 293 | load_from_cache_file=not data_args.overwrite_cache, 294 | desc="Running tokenizer on prediction dataset", 295 | ) 296 | 
print_dataset_example(predict_dataset[0]) 297 | 298 | # Data collator 299 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 300 | data_collator = DataCollatorForSeq2Seq( 301 | tokenizer, 302 | model=model, 303 | label_pad_token_id=label_pad_token_id, 304 | pad_to_multiple_of=None, 305 | padding=False 306 | ) 307 | 308 | # Metric 309 | def compute_metrics(eval_preds): 310 | preds, labels = eval_preds 311 | if isinstance(preds, tuple): 312 | preds = preds[0] 313 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 314 | if data_args.ignore_pad_token_for_loss: 315 | # Replace -100 in the labels as we can't decode them. 316 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 317 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 318 | 319 | score_dict = { 320 | "rouge-1": [], 321 | "rouge-2": [], 322 | "rouge-l": [], 323 | "bleu-4": [] 324 | } 325 | for pred, label in zip(decoded_preds, decoded_labels): 326 | hypothesis = list(jieba.cut(pred)) 327 | reference = list(jieba.cut(label)) 328 | rouge = Rouge() 329 | scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) 330 | result = scores[0] 331 | 332 | for k, v in result.items(): 333 | score_dict[k].append(round(v["f"] * 100, 4)) 334 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 335 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 336 | 337 | for k, v in score_dict.items(): 338 | score_dict[k] = float(np.mean(v)) 339 | return score_dict 340 | 341 | # Override the decoding parameters of Seq2SeqTrainer 342 | training_args.generation_max_length = ( 343 | training_args.generation_max_length 344 | if training_args.generation_max_length is not None 345 | else data_args.val_max_target_length 346 | ) 347 | training_args.generation_num_beams = ( 348 | data_args.num_beams if data_args.num_beams is not None else 
training_args.generation_num_beams 349 | ) 350 | # Initialize our Trainer 351 | trainer = Seq2SeqTrainer( 352 | model=model, 353 | args=training_args, 354 | train_dataset=train_dataset if training_args.do_train else None, 355 | eval_dataset=eval_dataset if training_args.do_eval else None, 356 | tokenizer=tokenizer, 357 | data_collator=data_collator, 358 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 359 | save_prefixencoder=model_args.pre_seq_len is not None 360 | ) 361 | 362 | # Training 363 | if training_args.do_train: 364 | checkpoint = None 365 | if training_args.resume_from_checkpoint is not None: 366 | checkpoint = training_args.resume_from_checkpoint 367 | # elif last_checkpoint is not None: 368 | # checkpoint = last_checkpoint 369 | model.gradient_checkpointing_enable() 370 | model.enable_input_require_grads() 371 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 372 | # trainer.save_model() # Saves the tokenizer too for easy upload 373 | 374 | metrics = train_result.metrics 375 | max_train_samples = ( 376 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 377 | ) 378 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 379 | 380 | trainer.log_metrics("train", metrics) 381 | trainer.save_metrics("train", metrics) 382 | trainer.save_state() 383 | 384 | # Evaluation 385 | results = {} 386 | if training_args.do_eval: 387 | logger.info("*** Evaluate ***") 388 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95) 389 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 390 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 391 | 392 | trainer.log_metrics("eval", metrics) 393 | trainer.save_metrics("eval", metrics) 394 | 395 | if training_args.do_predict: 396 | logger.info("*** Predict ***") 397 | 398 | 
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95) 399 | metrics = predict_results.metrics 400 | max_predict_samples = ( 401 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 402 | ) 403 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 404 | 405 | trainer.log_metrics("predict", metrics) 406 | trainer.save_metrics("predict", metrics) 407 | 408 | if trainer.is_world_process_zero(): 409 | if training_args.predict_with_generate: 410 | predictions = tokenizer.batch_decode( 411 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 412 | ) 413 | predictions = [pred.strip() for pred in predictions] 414 | labels = tokenizer.batch_decode( 415 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 416 | ) 417 | labels = [label.strip() for label in labels] 418 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 419 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 420 | for p, l in zip(predictions, labels): 421 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) 422 | writer.write(f"{res}\n") 423 | return results 424 | 425 | 426 | def _mp_fn(index): 427 | # For xla_spawn (TPUs) 428 | main() 429 | 430 | 431 | if __name__ == "__main__": 432 | main() 433 | -------------------------------------------------------------------------------- /src/ptuning/prediction.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | CHECKPOINT=0523-bio_prompt_1-chatglm-6b-pt-128-2e-2-bs8-accumulation2 3 | STEP=34900 4 | 5 | CUDA_VISIBLE_DEVICES=1 python3 main.py \ 6 | --do_predict \ 7 | --validation_file ../../data/val_prompt.json \ 8 | --test_file ../../data/test_prompt.json \ 9 | --overwrite_cache \ 10 | --prompt_column prompt \ 
11 | --response_column completion \ 12 | --model_name_or_path /home/shinian.ljl/projects/ChatGLM-6B/THUDM/chatglm-6b \ 13 | --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \ 14 | --output_dir ./output/$CHECKPOINT \ 15 | --overwrite_output_dir \ 16 | --max_source_length 256 \ 17 | --max_target_length 256 \ 18 | --per_device_eval_batch_size 1 \ 19 | --predict_with_generate \ 20 | --pre_seq_len $PRE_SEQ_LEN 21 | -------------------------------------------------------------------------------- /src/ptuning/train.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | LR=2e-2 3 | 4 | CUDA_VISIBLE_DEVICES=1 python3 main.py \ 5 | --do_train \ 6 | --train_file /home/shinian.ljl/data/bio/CMedQA/train_prompt_1.json \ 7 | --validation_file /home/shinian.ljl/data/bio/CMedQA/val_prompt_1.json \ 8 | --prompt_column prompt \ 9 | --response_column completion \ 10 | --overwrite_cache \ 11 | --model_name_or_path /home/shinian.ljl/projects/ChatGLM-6B/THUDM/chatglm-6b \ 12 | --output_dir output/0813-bio_prompt_1-chatglm-6b-pt-$PRE_SEQ_LEN-$LR-bs8-accumulation2 \ 13 | --overwrite_output_dir \ 14 | --max_source_length 256 \ 15 | --max_target_length 256 \ 16 | --per_device_train_batch_size 8 \ 17 | --per_device_eval_batch_size 8 \ 18 | --gradient_accumulation_steps 2 \ 19 | --predict_with_generate \ 20 | --max_steps 50000 \ 21 | --logging_steps 10 \ 22 | --save_steps 500 \ 23 | --learning_rate $LR \ 24 | --pre_seq_len $PRE_SEQ_LEN \ 25 | --report_to wandb -------------------------------------------------------------------------------- /src/ptuning/trainer_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import Dataset 20 | 21 | from transformers.deepspeed import is_deepspeed_zero3_enabled 22 | from trainer import Trainer 23 | from transformers.trainer_utils import PredictionOutput 24 | from transformers.utils import logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class Seq2SeqTrainer(Trainer): 31 | def evaluate( 32 | self, 33 | eval_dataset: Optional[Dataset] = None, 34 | ignore_keys: Optional[List[str]] = None, 35 | metric_key_prefix: str = "eval", 36 | **gen_kwargs 37 | ) -> Dict[str, float]: 38 | """ 39 | Run evaluation and return metrics. 40 | 41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 42 | (pass it to the init `compute_metrics` argument). 43 | 44 | You can also subclass and override this method to inject custom behavior. 45 | 46 | Args: 47 | eval_dataset (`Dataset`, *optional*): 48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns 49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 50 | method. 51 | ignore_keys (`List[str]`, *optional*): 52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 53 | gathering predictions.
54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 55 | An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named 56 | "eval_bleu" if the prefix is `"eval"` (default). 57 | max_length (`int`, *optional*): 58 | The maximum target length to use when predicting with the generate method. 59 | num_beams (`int`, *optional*): 60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 61 | beam search. 62 | gen_kwargs: 63 | Additional `generate` specific kwargs. 64 | 65 | Returns: 66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 67 | dictionary also contains the epoch number which comes from the training state. 68 | """ 69 | 70 | gen_kwargs = gen_kwargs.copy() 71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 72 | gen_kwargs["max_length"] = self.args.generation_max_length 73 | gen_kwargs["num_beams"] = ( 74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 75 | ) 76 | self._gen_kwargs = gen_kwargs 77 | 78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 79 | 80 | def predict( 81 | self, 82 | test_dataset: Dataset, 83 | ignore_keys: Optional[List[str]] = None, 84 | metric_key_prefix: str = "test", 85 | **gen_kwargs 86 | ) -> PredictionOutput: 87 | """ 88 | Run prediction and return predictions and potential metrics. 89 | 90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 91 | will also return metrics, like in `evaluate()`. 92 | 93 | Args: 94 | test_dataset (`Dataset`): 95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the 96 | `model.forward()` method are automatically removed.
Has to implement the method `__len__`. 97 | ignore_keys (`List[str]`, *optional*): 98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 99 | gathering predictions. 100 | metric_key_prefix (`str`, *optional*, defaults to `"test"`): 101 | An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named 102 | "test_bleu" if the prefix is `"test"` (default). 103 | max_length (`int`, *optional*): 104 | The maximum target length to use when predicting with the generate method. 105 | num_beams (`int`, *optional*): 106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 107 | beam search. 108 | gen_kwargs: 109 | Additional `generate` specific kwargs. 110 | 111 | 112 | 113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic 114 | padding in a token classification task) the predictions will be padded (on the right) to allow for 115 | concatenation into one array. The padding index is -100. 116 | 117 | 118 | 119 | Returns: *NamedTuple* A namedtuple with the following keys: 120 | 121 | - predictions (`np.ndarray`): The predictions on `test_dataset`. 122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). 123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained 124 | labels).
125 | """ 126 | 127 | gen_kwargs = gen_kwargs.copy() 128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 129 | gen_kwargs["max_length"] = self.args.generation_max_length 130 | gen_kwargs["num_beams"] = ( 131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 132 | ) 133 | self._gen_kwargs = gen_kwargs 134 | 135 | 136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 137 | 138 | def prediction_step( 139 | self, 140 | model: nn.Module, 141 | inputs: Dict[str, Union[torch.Tensor, Any]], 142 | prediction_loss_only: bool, 143 | ignore_keys: Optional[List[str]] = None, 144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 145 | """ 146 | Perform an evaluation step on `model` using `inputs`. 147 | 148 | Subclass and override to inject custom behavior. 149 | 150 | Args: 151 | model (`nn.Module`): 152 | The model to evaluate. 153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 154 | The inputs and targets of the model. 155 | 156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 157 | argument `labels`. Check your model's documentation for all accepted arguments. 158 | prediction_loss_only (`bool`): 159 | Whether or not to return the loss only. 160 | 161 | Return: 162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 163 | labels (each being optional). 
164 | """ 165 | 166 | if not self.args.predict_with_generate or prediction_loss_only: 167 | return super().prediction_step( 168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 169 | ) 170 | 171 | has_labels = "labels" in inputs 172 | inputs = self._prepare_inputs(inputs) 173 | 174 | # XXX: adapt synced_gpus for fairscale as well 175 | gen_kwargs = self._gen_kwargs.copy() 176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 177 | gen_kwargs["max_length"] = self.model.config.max_length 178 | gen_kwargs["num_beams"] = ( 179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 180 | ) 181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 182 | gen_kwargs["synced_gpus"] = ( 183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 184 | ) 185 | 186 | if "attention_mask" in inputs: 187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 188 | if "position_ids" in inputs: 189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 190 | if "global_attention_mask" in inputs: 191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 192 | 193 | # prepare generation inputs 194 | # some encoder-decoder models can have varying encoder's and thus 195 | # varying model input names 196 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 197 | generation_inputs = inputs[self.model.encoder.main_input_name] 198 | else: 199 | generation_inputs = inputs[self.model.main_input_name] 200 | 201 | gen_kwargs["input_ids"] = generation_inputs 202 | gen_kwargs["num_return_sequences"] = gen_kwargs["num_beams"] 203 | generated_tokens = self.model.generate(**gen_kwargs) 204 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] 205 | 206 | # in case the batch is shorter than max length, 
the output should be padded 207 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: 208 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 209 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( 210 | gen_kwargs["max_new_tokens"] + 1 211 | ): 212 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 213 | 214 | loss = None 215 | 216 | if self.args.prediction_loss_only: 217 | return (loss, None, None) 218 | 219 | if has_labels: 220 | labels = inputs["labels"] 221 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: 222 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 223 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( 224 | gen_kwargs["max_new_tokens"] + 1 225 | ): 226 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) 227 | else: 228 | labels = None 229 | 230 | return (loss, generated_tokens, labels) 231 | 232 | def _pad_tensors_to_max_len(self, tensor, max_length): 233 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 234 | # If PAD token is not defined at least EOS token has to be defined 235 | pad_token_id = ( 236 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 237 | ) 238 | else: 239 | if self.model.config.pad_token_id is not None: 240 | pad_token_id = self.model.config.pad_token_id 241 | else: 242 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 243 | 244 | padded_tensor = pad_token_id * torch.ones( 245 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 246 | ) 247 | padded_tensor[:, : tensor.shape[-1]] = tensor 248 | return padded_tensor 249 | 
-------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | aiofiles==23.1.0 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | altair==4.2.2 6 | anyio==3.6.2 7 | appdirs==1.4.4 8 | async-timeout==4.0.2 9 | attrs==23.1.0 10 | certifi==2022.12.7 11 | charset-normalizer==3.1.0 12 | click==8.1.3 13 | cmake==3.26.3 14 | contourpy==1.0.7 15 | cpm-kernels==1.0.11 16 | cycler==0.11.0 17 | datasets==2.12.0 18 | dill==0.3.6 19 | docker-pycreds==0.4.0 20 | entrypoints==0.4 21 | fastapi==0.95.1 22 | ffmpy==0.3.0 23 | filelock==3.12.0 24 | fonttools==4.39.3 25 | frozenlist==1.3.3 26 | fsspec==2023.4.0 27 | gitdb==4.0.10 28 | GitPython==3.1.31 29 | gradio==3.27.0 30 | gradio_client==0.1.3 31 | h11==0.14.0 32 | httpcore==0.17.0 33 | httpx==0.24.0 34 | huggingface-hub==0.13.4 35 | idna==3.4 36 | importlib-metadata==6.6.0 37 | importlib-resources==5.12.0 38 | jieba==0.42.1 39 | Jinja2==3.1.2 40 | joblib==1.2.0 41 | jsonschema==4.17.3 42 | kiwisolver==1.4.4 43 | latex2mathml==3.75.2 44 | linkify-it-py==2.0.0 45 | lit==16.0.1 46 | Markdown==3.4.3 47 | markdown-it-py==2.2.0 48 | MarkupSafe==2.1.2 49 | matplotlib==3.7.1 50 | mdit-py-plugins==0.3.3 51 | mdtex2html==1.2.0 52 | mdurl==0.1.2 53 | mpmath==1.3.0 54 | multidict==6.0.4 55 | multiprocess==0.70.14 56 | networkx==3.1 57 | nltk==3.8.1 58 | numpy==1.24.3 59 | nvidia-cublas-cu11==11.10.3.66 60 | nvidia-cuda-cupti-cu11==11.7.101 61 | nvidia-cuda-nvrtc-cu11==11.7.99 62 | nvidia-cuda-runtime-cu11==11.7.99 63 | nvidia-cudnn-cu11==8.5.0.96 64 | nvidia-cufft-cu11==10.9.0.58 65 | nvidia-curand-cu11==10.2.10.91 66 | nvidia-cusolver-cu11==11.4.0.1 67 | nvidia-cusparse-cu11==11.7.4.91 68 | nvidia-nccl-cu11==2.14.3 69 | nvidia-nvtx-cu11==11.7.91 70 | orjson==3.8.10 71 | packaging==23.1 72 | pandas==2.0.0 73 | pathtools==0.1.2 74 | peft==0.4.0 75 | Pillow==9.5.0 76 | protobuf==4.22.3 77 | 
psutil==5.9.5 78 | pyarrow==12.0.0 79 | pydantic==1.10.7 80 | pydub==0.25.1 81 | pyparsing==3.0.9 82 | pyrsistent==0.19.3 83 | python-dateutil==2.8.2 84 | python-multipart==0.0.6 85 | pytz==2023.3 86 | PyYAML==6.0 87 | regex==2023.3.23 88 | requests==2.28.2 89 | responses==0.18.0 90 | rouge-chinese==1.0.3 91 | safetensors==0.3.1 92 | semantic-version==2.10.0 93 | sentencepiece==0.1.98 94 | sentry-sdk==1.23.0 95 | setproctitle==1.3.2 96 | six==1.16.0 97 | smmap==5.0.0 98 | sniffio==1.3.0 99 | starlette==0.26.1 100 | sympy==1.11.1 101 | tokenizers==0.13.3 102 | toolz==0.12.0 103 | torch==2.0.0 104 | tqdm==4.65.0 105 | transformers==4.27.1 106 | triton==2.0.0 107 | typing_extensions==4.5.0 108 | tzdata==2023.3 109 | uc-micro-py==1.0.1 110 | urllib3==1.26.15 111 | uvicorn==0.21.1 112 | wandb==0.15.2 113 | websockets==11.0.2 114 | xxhash==3.2.0 115 | yarl==1.9.1 116 | zipp==3.15.0 117 | --------------------------------------------------------------------------------
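In `src/ptuning/main.py`, the dataset loader name passed to `load_dataset` is taken from the extension of whichever split file was listed last (`split(".")[-1]`), so `val_prompt.json` resolves to the `json` builder. The sketch below is a hypothetical, stricter variant of that step: the helper name `infer_loader_name` is ours, not the repo's, and it uses `os.path.splitext` and rejects mixed file types instead of silently keeping the last extension.

```python
import os
from typing import Dict, Optional


def infer_loader_name(data_files: Dict[str, Optional[str]]) -> str:
    """Return the single file extension (e.g. "json", "csv") shared by all given split files.

    Hypothetical helper: main.py simply keeps the extension of the last file it sees.
    """
    extensions = set()
    for split, path in data_files.items():
        if path is None:
            continue  # this split was not passed on the command line
        ext = os.path.splitext(path)[1].lstrip(".").lower()
        if not ext:
            raise ValueError(f"{split} file {path!r} has no extension")
        extensions.add(ext)
    if len(extensions) != 1:
        raise ValueError(f"expected exactly one file type, got {sorted(extensions)}")
    return extensions.pop()


files = {
    "train": None,
    "validation": "../../data/val_prompt.json",
    "test": "../../data/test_prompt.json",
}
print(infer_loader_name(files))  # json
```

The returned name would take the place of `extension` as the first argument of `load_dataset`; failing early on a `.json`/`.csv` mix avoids loading one split with the wrong builder.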
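When `--ptuning_checkpoint` is set, `main.py` loads `pytorch_model.bin` and keeps only the keys that start with `transformer.prefix_encoder.`, stripping that prefix because `load_state_dict` is then called on the `prefix_encoder` submodule, which expects parameter names relative to itself. A standalone sketch of the same filtering; plain strings stand in for tensors, and the key `embedding.weight` is only an illustrative assumption about the prefix encoder's parameter names.

```python
from typing import Any, Dict

PREFIX = "transformer.prefix_encoder."


def extract_prefix_encoder_state(state_dict: Dict[str, Any], prefix: str = PREFIX) -> Dict[str, Any]:
    """Keep only prefix-encoder entries and drop the module path in front of each key."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


# Illustrative checkpoint contents (strings instead of tensors).
checkpoint = {
    "transformer.prefix_encoder.embedding.weight": "tensor-A",
    "transformer.word_embeddings.weight": "tensor-B",
}
print(extract_prefix_encoder_state(checkpoint))  # {'embedding.weight': 'tensor-A'}
```

On Python 3.9+ the slice could be written as `key.removeprefix(prefix)`; the slice form matches the script and works on older interpreters.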
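`preprocess_function_eval` and `preprocess_function_train` in `main.py` turn an optional chat history into one prompt string: each past turn becomes a `[Round i]` block with a `问:` (question) line and a `答:` (answer) line, and the current query ends with an open `答:` for the model to complete. The sketch below reproduces that formatting and follows the evaluation path, which falls back to the bare query when the history is missing or empty; the function name `build_prompt` is hypothetical, and the script additionally prepends `source_prefix`.

```python
from typing import Sequence, Tuple


def build_prompt(query: str, history: Sequence[Tuple[str, str]] = ()) -> str:
    """Format a query plus optional (question, answer) history in the [Round N] layout."""
    if not history:
        return query  # single-turn example: the raw query is the whole prompt
    prompt = ""
    for turn_idx, (old_query, response) in enumerate(history):
        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
    # The current turn is left open after 答: so generation continues from there.
    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
    return prompt


print(build_prompt("What is X?"))  # What is X?
print(build_prompt("And Y?", [("What is X?", "X is ...")]))
```

The training scripts in this repo do not pass `--history_column`, so in practice every prompt takes the single-turn branch.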
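`preprocess_function_train` builds one fixed-length training sequence per example and masks the prompt part of the labels with `-100`, so the loss is computed only on the answer. The sketch below mirrors that logic with plain lists. It assumes, as the `context_length`/`mask_position` lookups suggest, that ChatGLM-6B's `build_inputs_with_special_tokens` emits the prompt, a `[gMASK]` token, a BOS token, the answer, and an EOS token; the integer IDs are placeholders, not the real vocabulary.

```python
from typing import List, Tuple

# Placeholder special-token IDs (assumption); real values come from the ChatGLM tokenizer.
PAD_ID, GMASK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_train_example(
    a_ids: List[int],
    b_ids: List[int],
    max_source_length: int,
    max_target_length: int,
    ignore_pad_token_for_loss: bool = True,
) -> Tuple[List[int], List[int]]:
    max_seq_length = max_source_length + max_target_length
    # Truncate to leave room for the special tokens added around prompt and answer.
    a_ids = a_ids[: max_source_length - 1]
    b_ids = b_ids[: max_target_length - 2]
    # Assumed layout: prompt, [gMASK], BOS, answer, EOS.
    input_ids = a_ids + [GMASK_ID, BOS_ID] + b_ids + [EOS_ID]
    # Note: example token IDs must not collide with BOS_ID, or index() finds the wrong spot.
    context_length = input_ids.index(BOS_ID)
    mask_position = context_length - 1
    # Prompt and [gMASK] positions are ignored; supervision starts at BOS.
    labels = [IGNORE_INDEX] * context_length + input_ids[mask_position + 1 :]
    pad_len = max_seq_length - len(input_ids)
    input_ids = input_ids + [PAD_ID] * pad_len
    labels = labels + [PAD_ID] * pad_len
    if ignore_pad_token_for_loss:
        labels = [(l if l != PAD_ID else IGNORE_INDEX) for l in labels]
    return input_ids, labels


ids, labels = build_train_example([10, 11, 12], [20, 21], max_source_length=5, max_target_length=4)
print(ids)     # [10, 11, 12, 1, 2, 20, 21, 3, 0]
print(labels)  # [-100, -100, -100, -100, 2, 20, 21, 3, -100]
```

The truncation budget explains the `- 1` and `- 2`: under the assumed layout a full example has at most `(S - 1) + 2 + (T - 2) + 1 = S + T` tokens, exactly `max_seq_length` (512 for the 256 + 256 used in `train.sh`), so `pad_len` is never negative.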
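`compute_metrics` in `main.py` first swaps the `-100` loss mask back to `pad_token_id` (the tokenizer cannot decode `-100`), then scores every prediction/reference pair, multiplies each score by 100, rounds to four decimals, and finally averages each metric with `np.mean`. The sketch keeps that skeleton but uses the standard library instead of NumPy and replaces the real scorers (`rouge_chinese.Rouge` over `jieba` word tokens, and NLTK's `sentence_bleu` over characters) with a toy exact-match function, so it illustrates only the bookkeeping, not ROUGE or BLEU themselves.

```python
from statistics import mean
from typing import Callable, Dict, List, Sequence

IGNORE_INDEX = -100


def restore_pad_ids(labels: Sequence[Sequence[int]], pad_token_id: int) -> List[List[int]]:
    """Replace the loss-masking value with a real token ID so the labels can be decoded."""
    return [[tok if tok != IGNORE_INDEX else pad_token_id for tok in row] for row in labels]


def average_scores(
    preds: Sequence[str],
    refs: Sequence[str],
    scorers: Dict[str, Callable[[str, str], float]],
) -> Dict[str, float]:
    """Score each (prediction, reference) pair as a percentage, then average per metric."""
    per_example: Dict[str, List[float]] = {name: [] for name in scorers}
    for pred, ref in zip(preds, refs):
        for name, scorer in scorers.items():
            per_example[name].append(round(scorer(pred, ref) * 100, 4))
    return {name: float(mean(values)) for name, values in per_example.items()}


def exact_match(pred: str, ref: str) -> float:
    # Toy stand-in for ROUGE/BLEU, used only to exercise average_scores.
    return 1.0 if pred == ref else 0.0


print(restore_pad_ids([[5, 6, -100, -100]], pad_token_id=3))  # [[5, 6, 3, 3]]
print(average_scores(["a", "b"], ["a", "c"], {"exact": exact_match}))  # {'exact': 50.0}
```

Because the score is a mean of per-example values, it is a macro average: short and long answers weigh the same.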
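`Seq2SeqTrainer.evaluate` and `Seq2SeqTrainer.predict` in `src/ptuning/trainer_seq2seq.py` copy the caller's generation kwargs and fill in only what is missing: `max_length` falls back to `args.generation_max_length` unless either `max_length` or `max_new_tokens` was given, and `num_beams` falls back to `args.generation_num_beams`. A dependency-free sketch of that merge; the function name and the default values in the example are placeholders.

```python
from typing import Any, Dict


def resolve_gen_kwargs(
    gen_kwargs: Dict[str, Any], default_max_length: int, default_num_beams: int
) -> Dict[str, Any]:
    """Fill in generation defaults without mutating the caller's dict."""
    resolved = dict(gen_kwargs)  # shallow copy, like gen_kwargs.copy()
    # A caller-supplied max_length or max_new_tokens wins over the default length.
    if resolved.get("max_length") is None and resolved.get("max_new_tokens") is None:
        resolved["max_length"] = default_max_length
    if resolved.get("num_beams") is None:
        resolved["num_beams"] = default_num_beams
    return resolved


# main.py calls predict(..., max_length=512, do_sample=True, top_p=0.7, temperature=0.95).
print(resolve_gen_kwargs(
    {"max_length": 512, "do_sample": True, "top_p": 0.7, "temperature": 0.95},
    default_max_length=256,
    default_num_beams=1,
))
```

Since `main.py` passes `max_length=512` explicitly to both `evaluate` and `predict`, the `generation_max_length` fallback computed earlier in the script does not take effect for those calls.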
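In `prediction_step`, the tokens returned by `self.model.generate` still begin with the prompt (ChatGLM is decoder-only), so the trainer slices off the first `generation_inputs.size()[-1]` columns, and `_pad_tensors_to_max_len` then right-pads the batch to `max_length` with the tokenizer's pad ID (falling back to EOS) so that batches can be concatenated. The list-based sketch below reproduces both steps without PyTorch; it assumes every prompt in the batch has the same width, which the tensor slice in the original also requires.

```python
from typing import List, Sequence


def strip_prompt(generated: Sequence[Sequence[int]], prompt_width: int) -> List[List[int]]:
    """Drop the echoed prompt columns from each generated row."""
    return [list(row[prompt_width:]) for row in generated]


def pad_to_max_len(rows: Sequence[Sequence[int]], max_length: int, pad_token_id: int) -> List[List[int]]:
    """Right-pad every row with pad_token_id up to max_length; longer rows are left unchanged."""
    return [list(row) + [pad_token_id] * (max_length - len(row)) for row in rows]


# Two left-padded prompts of width 3 (pad ID 0), each followed by generated tokens.
generated = [[7, 8, 9, 41, 42], [0, 8, 9, 51, 3]]
answers = strip_prompt(generated, prompt_width=3)
print(answers)                                              # [[41, 42], [51, 3]]
print(pad_to_max_len(answers, max_length=4, pad_token_id=0))  # [[41, 42, 0, 0], [51, 3, 0, 0]]
```

One detail worth checking in the original: `prediction_step` also sets `num_return_sequences` to `num_beams`, and `generate` returns `batch_size * num_return_sequences` rows, so with `num_beams > 1` the predictions would no longer line up one-to-one with `labels`. The shell scripts shown here do not pass `--num_beams`.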
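A worked check of what `src/ptuning/train.sh` requests. With a single visible GPU (`CUDA_VISIBLE_DEVICES=1`), `--per_device_train_batch_size 8` and `--gradient_accumulation_steps 2` give 16 examples per optimizer step, and the remaining flags fix the run length, checkpoint cadence, and padded sequence length. The arithmetic assumes a single training process and ignores any checkpoint-deletion limit such as `save_total_limit`.

```python
# Values copied from src/ptuning/train.sh; one GPU is visible, one process is assumed.
per_device_train_batch_size = 8
gradient_accumulation_steps = 2
num_gpus = 1
max_steps = 50_000
save_steps = 500
max_source_length = 256
max_target_length = 256

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
# Examples drawn over the whole run, counting repeats across epochs.
examples_seen = effective_batch_size * max_steps
checkpoint_saves = max_steps // save_steps
# preprocess_function_train pads every example to max_source_length + max_target_length.
max_seq_length = max_source_length + max_target_length

print(effective_batch_size)  # 16
print(examples_seen)         # 800000
print(checkpoint_saves)      # 100
print(max_seq_length)        # 512
```

The `bs8-accumulation2` suffix in the output directory name records the first two of these settings.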