├── .gitignore
├── LICENSE
├── README.md
├── conference_material
│   ├── poster.pdf
│   └── presentation.pdf
├── data
│   ├── test_with_annotations.csv
│   ├── train.csv
│   └── val.csv
├── docs
│   ├── Annotation Guidelines.txt
│   ├── Comparison_with_MLEC-QA.jpg
│   ├── example.png
│   └── overall_comparison.jpg
└── src
    ├── LoRA
    │   ├── finetune.py
    │   ├── generate.py
    │   ├── infer.py
    │   ├── scripts
    │   │   ├── finetune.sh
    │   │   ├── infer_ori.sh
    │   │   └── infer_sft.sh
    │   ├── templates
    │   │   ├── README.md
    │   │   └── med_template.json
    │   └── utils
    │       ├── README.md
    │       ├── __init__.py
    │       ├── data_format_transform.py
    │       └── prompter.py
    ├── evaluation
    │   ├── evaluate
    │   │   ├── bleu.py
    │   │   ├── metrics4rec.py
    │   │   ├── rouge.py
    │   │   └── utils.py
    │   ├── evaluate_chatglm_result.py
    │   ├── evaluate_ft_result.py
    │   ├── evaluate_gpt_result.py
    │   └── evaluate_lora_result.py
    ├── preprocess
    │   ├── data stats.ipynb
    │   ├── dataset_dist.pdf
    │   ├── generate_prompt.py
    │   └── prompt_templates.py
    ├── ptuning
    │   ├── arguments.py
    │   ├── deepspeed.json
    │   ├── main.py
    │   ├── prediction.sh
    │   ├── train.sh
    │   ├── trainer.py
    │   └── trainer_seq2seq.py
    └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 |    Copyright Junling Liu
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Note: If you are looking for a multimodal dataset, check out our new dataset, **ChiMed-VL-Instruction**, with 469,441 vision-language QA pairs: [https://paperswithcode.com/dataset/qilin-med-vl](https://paperswithcode.com/dataset/qilin-med-vl)
2 |
3 | This paper was presented at NeurIPS 2023, New Orleans, Louisiana. See here for the [poster](conference_material/poster.pdf) and [slides](conference_material/presentation.pdf).
4 |
5 | # Benchmarking Large Language Models on CMExam - A Comprehensive Chinese Medical Exam Dataset
6 |
7 | ## Introduction
8 |
9 | CMExam is a dataset sourced from the Chinese National Medical Licensing Examination. It consists of 60K+ multiple-choice questions, each paired with five additional question-wise annotations: disease groups, clinical departments, medical disciplines, areas of competency, and question difficulty levels. Alongside the dataset, we conducted comprehensive benchmarks of representative LLMs on CMExam.
10 |
11 |
12 |
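As a quick sanity check, the released splits can be inspected directly with pandas. The snippet below is a minimal sketch (not part of the original pipeline); it only assumes the files listed under `data/` and prints whatever columns the CSVs actually contain rather than hard-coding names:

```
import pandas as pd

# Paths follow the repository layout under data/
splits = {
    "train": "data/train.csv",
    "val": "data/val.csv",
    "test": "data/test_with_annotations.csv",
}

for name, path in splits.items():
    df = pd.read_csv(path, dtype=str)
    # Report basic shape and the column names actually present in each file
    print(f"{name}: {len(df)} rows, columns = {list(df.columns)}")
```
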
13 | ## Dataset Statistics
14 | | | Train | Val | Test | Total |
15 | |----------------------------|---------------|---------------|---------------|---------------|
16 | | Question | 54,497 | 6,811 | 6,811 | 68,119 |
17 | | Vocab | 4,545 | 3,620 | 3,599 | 4,629 |
18 | | Max Q tokens | 676 | 500 | 585 | 676 |
19 | | Max A tokens | 5 | 5 | 5 | 5 |
20 | | Max E tokens | 2,999 | 2,678 | 2,680 | 2,999 |
21 | | Avg Q tokens | 29.78 | 30.07 | 32.63 | 30.83 |
22 | | Avg A tokens | 1.08 | 1.07 | 1.07 | 1.07 |
23 | | Avg E tokens | 186.24 | 188.95 | 201.44 | 192.21 |
24 | | Median (Q1, Q3) Q tokens | 17 (12, 32) | 18 (12, 32) | 18 (12, 37) | 18 (12, 32) |
25 | | Median (Q1, Q3) A tokens | 1 (1, 1) | 1 (1, 1) | 1 (1, 1) | 1 (1, 1) |
26 | | Median (Q1, Q3) E tokens | 146 (69, 246) | 143 (65, 247) | 158 (80, 263) | 146 (69, 247) |
27 |
28 | \*Q: Question; A: Answer; E: Explanation
29 |
30 | ## Annotation Characteristics
31 | | Annotation Content | References | Unique values |
32 | |----------------------------|-----------------------------|---------------|
33 | | Disease Groups             | The International Classification of Diseases, 11th Revision (ICD-11) | 27            |
34 | | Clinical Departments | The Directory of Medical Institution Diagnostic and Therapeutic Categories (DMIDTC) | 36 |
35 | | Medical Disciplines | List of Graduate Education Disciplinary Majors (2022) | 7 |
36 | | Medical Competencies | Medical Professionals | 4 |
37 | | Difficulty Level | Human Performance | 5 |
38 |
39 | ## Benchmarks
40 |
41 | Alongside the dataset, we further conducted thorough experiments with representative LLMs and QA algorithms on CMExam.
42 |
43 |
44 |
45 | ## Deployment
46 |
47 | To deploy this project, run the following steps:
48 |
49 | ### Environment Setup
50 | ```
51 | cd src
52 | pip install -r requirements.txt
53 | ```
54 | ### Data Preprocess
55 | ```
56 | cd preprocess
57 | python generate_prompt.py
58 | ```
59 |
60 | ### Ptuning
61 | ```
62 | cd ../ptuning
63 | bash train.sh
64 | bash prediction.sh
65 | ```
66 |
67 | ### LoRA
68 | ```
69 | cd ../LoRA
70 | bash ./scripts/finetune.sh
71 | bash ./scripts/infer_ori.sh
72 | bash ./scripts/infer_sft.sh
73 | ```
74 |
75 | ### Evaluation
76 | ```
77 | cd ../evaluation
78 | python evaluate_lora_result.py --csv_file_path path/to/csv/file
79 | ```
80 |
81 | ## Side notes
82 | ### Limitations:
83 | - Excluding non-textual questions may introduce biases.
84 | - BLEU and ROUGE metrics are inadequate for fully assessing explanation quality; future work should include expert analysis.
85 | ### Ethics in Data Collection:
86 | - Adheres to legal and ethical guidelines.
87 | - Authenticated and accurate for evaluating LLMs.
88 | - Intended for academic/research use only; commercial misuse prohibited.
89 | - Users should acknowledge dataset limitations and specific context.
90 | - Not for assessing individual medical competence or patient diagnosis.
91 | ### Future directions:
92 | - Translate to English (in-progress)
93 | - Include multimodal information (our new dataset ChiMed-Vision-Language-Instruction - 469,441 QA pairs: [https://paperswithcode.com/dataset/qilin-med-vl](https://paperswithcode.com/dataset/qilin-med-vl))
94 |
95 | ## Citation
96 | Benchmarking Large Language Models on CMExam -- A Comprehensive Chinese Medical Exam Dataset
97 | https://arxiv.org/abs/2306.03030
98 |
99 | ```
100 | @article{liu2023benchmarking,
101 | title={Benchmarking Large Language Models on CMExam--A Comprehensive Chinese Medical Exam Dataset},
102 | author={Liu, Junling and Zhou, Peilin and Hua, Yining and Chong, Dading and Tian, Zhongyu and Liu, Andrew and Wang, Helin and You, Chenyu and Guo, Zhenhua and Zhu, Lei and others},
103 | journal={arXiv preprint arXiv:2306.03030},
104 | year={2023}
105 | }
106 | ```
107 |
--------------------------------------------------------------------------------
/conference_material/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/conference_material/poster.pdf
--------------------------------------------------------------------------------
/conference_material/presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/conference_material/presentation.pdf
--------------------------------------------------------------------------------
/docs/Annotation Guidelines.txt:
--------------------------------------------------------------------------------
1 | 1. Comprehensive Question Understanding: Prior to initiating the annotation process, meticulously comprehend the medical question, ensuring a holistic grasp of its context and significance.
2 | 2. Subject Categorization: Identify the precise subject or medical field that the question pertains to, such as cardiology, pediatrics, or pathology.
3 | 3. Principal Symptoms or Medical Conditions: Ascertain and pinpoint the primary symptoms or medical conditions expounded in the question.
4 | 4. Examination of Pertinent Factors: Scrutinize the question for any associated factors that might be present, including the severity of the ailment, its etiology, and patient history given in the question.
5 | 5. Appropriate Classification System Usage: Use the accurate classification system for annotation in alignment with the determined subject and symptoms. Suitable systems could encompass the 11th revision of the International Classification of Diseases (ICD-11), the Directory of Medical Institution Diagnostic and Therapeutic Categories (DMIDTC), and others.
6 | 6. Addressing Multiple Annotations: In scenarios where the question encompasses multiple symptoms or medical conditions, opt for the most related classification for annotation.
7 | 7. Ensuring High-Quality Annotations: Adhere to the guidelines and definitions within the chosen classification system. This diligence helps avert subjectivity and ambiguity, fostering precision in the annotations.
8 | 8. Navigating Queries and Uncertainties: Should any doubts or uncertainties emerge during the annotation process, consult the official documents and glossaries of the chosen classification system. Engaging in discussions with professionals is also advised to achieve clarity.
9 | 9. Resolving Discrepancies: When disagreements emerge between annotators, a collaborative discussion shall be initiated. The objective is to reach a consensus and unify the annotation decision.
--------------------------------------------------------------------------------
/docs/Comparison_with_MLEC-QA.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/docs/Comparison_with_MLEC-QA.jpg
--------------------------------------------------------------------------------
/docs/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/docs/example.png
--------------------------------------------------------------------------------
/docs/overall_comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/docs/overall_comparison.jpg
--------------------------------------------------------------------------------
/src/LoRA/finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import List
4 |
5 | import fire
6 | import torch
7 | import transformers
8 | from datasets import load_dataset
9 |
10 | """
11 | Unused imports:
12 | import torch.nn as nn
13 | import bitsandbytes as bnb
14 | """
15 |
16 | from peft import (
17 | LoraConfig,
18 | get_peft_model,
19 | get_peft_model_state_dict,
20 | prepare_model_for_int8_training,
21 | set_peft_model_state_dict,
22 | )
23 | from transformers import LlamaForCausalLM, LlamaTokenizer
24 |
25 | from utils.prompter import Prompter
26 | from utils.data_format_transform import filter_and_convert
27 |
28 |
29 | def train(
30 | # model/data params
31 | base_model: str = "medalpaca/medalpaca-7b", # the only required argument
32 | data_path: str = "../../data/train_prompt.json",
33 | valid_data_path: str = "../../data/val_prompt.json",
34 | output_dir: str = "./lora-medalpaca",
35 | prompt_id: str = '1',
36 | # training hyperparams
37 | batch_size: int = 128,
38 | micro_batch_size: int = 8,
39 | num_epochs: int = 2,
40 | learning_rate: float = 3e-4,
41 | cutoff_len: int = 256,
42 | val_set_size: int = 500,
43 | sample: int = None,
44 | # lora hyperparams
45 | lora_r: int = 8,
46 | lora_alpha: int = 16,
47 | lora_dropout: float = 0.05,
48 | lora_target_modules: List[str] = [
49 | "q_proj",
50 | "v_proj",
51 | ],
52 | # llm hyperparams
53 | train_on_inputs: bool = False, # if False, masks out inputs in loss
54 | group_by_length: bool = False, # faster, but produces an odd training loss curve
55 | # Others
56 | logging_steps: int = 8,
57 | eval_steps: int = 100,
58 | save_steps: int = 100,
59 | save_total_limit: int = 1000,
60 | # wandb params
61 | wandb_project: str = "llama_med",
62 | wandb_run_name: str = "",
63 | wandb_watch: str = "", # options: false | gradients | all
64 | wandb_log_model: str = "", # options: false | true
65 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter
66 | prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
67 | ):
68 | if int(os.environ.get("LOCAL_RANK", 0)) == 0:
69 | print(
70 | f"Training model with params:\n"
71 | f"base_model: {base_model}\n"
72 | f"data_path: {data_path}\n"
73 | f"output_dir: {output_dir}\n"
74 | f"batch_size: {batch_size}\n"
75 | f"micro_batch_size: {micro_batch_size}\n"
76 | f"num_epochs: {num_epochs}\n"
77 | f"learning_rate: {learning_rate}\n"
78 | f"cutoff_len: {cutoff_len}\n"
79 | f"val_set_size: {val_set_size}\n"
80 | f"lora_r: {lora_r}\n"
81 | f"lora_alpha: {lora_alpha}\n"
82 | f"lora_dropout: {lora_dropout}\n"
83 | f"lora_target_modules: {lora_target_modules}\n"
84 | f"train_on_inputs: {train_on_inputs}\n"
85 | f"group_by_length: {group_by_length}\n"
86 | f"wandb_project: {wandb_project}\n"
87 | f"wandb_run_name: {wandb_run_name}\n"
88 | f"wandb_watch: {wandb_watch}\n"
89 | f"wandb_log_model: {wandb_log_model}\n"
90 | f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
91 | f"prompt template: {prompt_template_name}\n"
92 | )
93 | assert (
94 | base_model
95 | ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
96 | gradient_accumulation_steps = batch_size // micro_batch_size
97 |
98 | prompter = Prompter(prompt_template_name)
99 |
100 | device_map = "auto"
101 | world_size = int(os.environ.get("WORLD_SIZE", 1))
102 | ddp = world_size != 1
103 | if ddp:
104 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
105 | gradient_accumulation_steps = gradient_accumulation_steps // world_size
106 |
107 | # Check if parameter passed or if set within environ
108 | use_wandb = len(wandb_project) > 0 or (
109 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
110 | )
111 | # Only overwrite environ if wandb param passed
112 | if len(wandb_project) > 0:
113 | os.environ["WANDB_PROJECT"] = wandb_project
114 | if len(wandb_watch) > 0:
115 | os.environ["WANDB_WATCH"] = wandb_watch
116 | if len(wandb_log_model) > 0:
117 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model
118 |
119 | model = LlamaForCausalLM.from_pretrained(
120 | base_model,
121 | load_in_8bit=True,
122 | torch_dtype=torch.float16,
123 | device_map=device_map,
124 | )
125 |
126 | tokenizer = LlamaTokenizer.from_pretrained(base_model)
127 |
128 | tokenizer.pad_token_id = (
129 | 0 # unk. we want this to be different from the eos token
130 | )
131 | tokenizer.padding_side = "left" # Allow batched inference
132 |
133 | def tokenize(prompt, add_eos_token=True):
134 | # there's probably a way to do this with the tokenizer settings
135 | # but again, gotta move fast
136 | result = tokenizer(
137 | prompt,
138 | truncation=True,
139 | max_length=cutoff_len,
140 | padding=False,
141 | return_tensors=None,
142 | )
143 | if (
144 | result["input_ids"][-1] != tokenizer.eos_token_id
145 | and len(result["input_ids"]) < cutoff_len
146 | and add_eos_token
147 | ):
148 | result["input_ids"].append(tokenizer.eos_token_id)
149 | result["attention_mask"].append(1)
150 |
151 | result["labels"] = result["input_ids"].copy()
152 |
153 | return result
154 |
155 | def generate_and_tokenize_prompt(data_point):
156 | full_prompt = prompter.generate_prompt(
157 | data_point["instruction"],
158 | data_point["input"],
159 | data_point["output"],
160 | )
161 | tokenized_full_prompt = tokenize(full_prompt)
162 | if not train_on_inputs:
163 | user_prompt = prompter.generate_prompt(
164 | data_point["instruction"], data_point["input"]
165 | )
166 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
167 | user_prompt_len = len(tokenized_user_prompt["input_ids"])
168 |
169 | tokenized_full_prompt["labels"] = [
170 | -100
171 | ] * user_prompt_len + tokenized_full_prompt["labels"][
172 | user_prompt_len:
173 | ] # could be sped up, probably
174 | return tokenized_full_prompt
175 |
176 | model = prepare_model_for_int8_training(model)
177 |
178 | config = LoraConfig(
179 | r=lora_r,
180 | lora_alpha=lora_alpha,
181 | target_modules=lora_target_modules,
182 | lora_dropout=lora_dropout,
183 | bias="none",
184 | task_type="CAUSAL_LM",
185 | )
186 | model = get_peft_model(model, config)
187 |
188 | filtered_train_data_path = filter_and_convert(data_path, prompt_id, sample)
189 | if filtered_train_data_path.endswith(".json") or filtered_train_data_path.endswith(".jsonl"):
190 | data = load_dataset("json", data_files=filtered_train_data_path)
191 | else:
192 | data = load_dataset(filtered_train_data_path)
193 | filtered_val_data_path = filter_and_convert(valid_data_path, prompt_id)
194 |
195 | if os.path.exists(filtered_val_data_path):
196 | if filtered_val_data_path.endswith(".json") or filtered_val_data_path.endswith(".jsonl"):
197 | valid_data = load_dataset("json", data_files=filtered_val_data_path)
198 | else:
199 |             valid_data = load_dataset(filtered_val_data_path)
200 | else:
201 | valid_data = None
202 |
203 | if resume_from_checkpoint:
204 | # Check the available weights and load them
205 | checkpoint_name = os.path.join(
206 | resume_from_checkpoint, "pytorch_model.bin"
207 | ) # Full checkpoint
208 | if not os.path.exists(checkpoint_name):
209 | checkpoint_name = os.path.join(
210 | resume_from_checkpoint, "adapter_model.bin"
211 | ) # only LoRA model - LoRA config above has to fit
212 | resume_from_checkpoint = (
213 | False # So the trainer won't try loading its state
214 | )
215 | # The two files above have a different name depending on how they were saved, but are actually the same.
216 | if os.path.exists(checkpoint_name):
217 | print(f"Restarting from {checkpoint_name}")
218 | adapters_weights = torch.load(checkpoint_name)
219 | model = set_peft_model_state_dict(model, adapters_weights)
220 | else:
221 | print(f"Checkpoint {checkpoint_name} not found")
222 |
223 | model.print_trainable_parameters() # Be more transparent about the % of trainable params.
224 |
225 | if val_set_size > 0 and not valid_data:
226 | train_val = data["train"].train_test_split(
227 | test_size=val_set_size, shuffle=True, seed=2023
228 | )
229 | train_data = (
230 | train_val["train"].shuffle().map(generate_and_tokenize_prompt)
231 | )
232 | val_data = (
233 | train_val["test"].shuffle().map(generate_and_tokenize_prompt)
234 | )
235 | elif val_set_size > 0 and valid_data:
236 | train_data = (
237 | data["train"].shuffle(seed=2023).map(generate_and_tokenize_prompt)
238 | )
239 | val_sample = valid_data["train"].train_test_split(
240 | test_size=val_set_size, shuffle=True, seed=2023
241 | )
242 | val_data = (
243 | val_sample["test"].shuffle().map(generate_and_tokenize_prompt)
244 | )
245 | else:
246 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
247 | val_data = None
248 |
249 | if not ddp and torch.cuda.device_count() > 1:
250 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
251 | model.is_parallelizable = True
252 | model.model_parallel = True
253 |
254 | trainer = transformers.Trainer(
255 | model=model,
256 | train_dataset=train_data,
257 | eval_dataset=val_data,
258 | args=transformers.TrainingArguments(
259 | per_device_train_batch_size=micro_batch_size,
260 | gradient_accumulation_steps=gradient_accumulation_steps,
261 | warmup_ratio=0.1,
262 | num_train_epochs=num_epochs,
263 | learning_rate=learning_rate,
264 | fp16=True,
265 | logging_steps=logging_steps,
266 | optim="adamw_torch",
267 | evaluation_strategy="steps" if val_set_size > 0 else "no",
268 | save_strategy="steps",
269 | eval_steps=eval_steps if val_set_size > 0 else None,
270 | save_steps=save_steps,
271 | output_dir=output_dir,
272 | save_total_limit=save_total_limit,
273 | load_best_model_at_end=True if val_set_size > 0 else False,
274 | ddp_find_unused_parameters=False if ddp else None,
275 | group_by_length=group_by_length,
276 | report_to="wandb" if use_wandb else None,
277 | run_name=wandb_run_name if use_wandb else None,
278 | ),
279 | data_collator=transformers.DataCollatorForSeq2Seq(
280 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
281 | ),
282 | )
283 | model.config.use_cache = False
284 |
285 | old_state_dict = model.state_dict
286 | model.state_dict = (
287 | lambda self, *_, **__: get_peft_model_state_dict(
288 | self, old_state_dict()
289 | )
290 | ).__get__(model, type(model))
291 |
292 | if torch.__version__ >= "2" and sys.platform != "win32":
293 | model = torch.compile(model)
294 |
295 | trainer.train(resume_from_checkpoint=resume_from_checkpoint)
296 |
297 | model.save_pretrained(output_dir)
298 |
299 | print(
300 | "\n If there's a warning about missing keys above, please disregard :)"
301 | )
302 |
303 |
304 | if __name__ == "__main__":
305 | fire.Fire(train)
306 |
--------------------------------------------------------------------------------
/src/LoRA/generate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import fire
4 | import gradio as gr
5 | import torch
6 | import transformers
7 | from peft import PeftModel
8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
9 |
10 | from utils.prompter import Prompter
11 |
12 | if torch.cuda.is_available():
13 | device = "cuda"
14 | else:
15 | device = "cpu"
16 |
17 | try:
18 | if torch.backends.mps.is_available():
19 | device = "mps"
20 | except: # noqa: E722
21 | pass
22 |
23 |
24 | def main(
25 | load_8bit: bool = False,
26 | base_model: str = "decapoda-research/llama-7b-hf",
27 | lora_weights: str = "tloen/alpaca-lora-7b",
28 | prompt_template: str = "med_template", # The prompt template to use, will default to alpaca.
29 | server_name: str = "0.0.0.0", # Allows to listen on all interfaces by providing '0.0.0.0'
30 | share_gradio: bool = True,
31 | ):
32 | assert (
33 | base_model
34 | ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
35 |
36 | prompter = Prompter(prompt_template)
37 | tokenizer = LlamaTokenizer.from_pretrained(base_model)
38 | if device == "cuda":
39 | model = LlamaForCausalLM.from_pretrained(
40 | base_model,
41 | load_in_8bit=load_8bit,
42 | torch_dtype=torch.float16,
43 | device_map="auto",
44 | )
45 | model = PeftModel.from_pretrained(
46 | model,
47 | lora_weights,
48 | torch_dtype=torch.float16,
49 | )
50 | elif device == "mps":
51 | model = LlamaForCausalLM.from_pretrained(
52 | base_model,
53 | device_map={"": device},
54 | torch_dtype=torch.float16,
55 | )
56 | model = PeftModel.from_pretrained(
57 | model,
58 | lora_weights,
59 | device_map={"": device},
60 | torch_dtype=torch.float16,
61 | )
62 | else:
63 | model = LlamaForCausalLM.from_pretrained(
64 | base_model, device_map={"": device}, low_cpu_mem_usage=True
65 | )
66 | model = PeftModel.from_pretrained(
67 | model,
68 | lora_weights,
69 | device_map={"": device},
70 | )
71 |
72 | # unwind broken decapoda-research config
73 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
74 | model.config.bos_token_id = 1
75 | model.config.eos_token_id = 2
76 |
77 | if not load_8bit:
78 | model.half() # seems to fix bugs for some users.
79 |
80 | model.eval()
81 | if torch.__version__ >= "2" and sys.platform != "win32":
82 | model = torch.compile(model)
83 |
84 | def evaluate(
85 | instruction,
86 | input=None,
87 | temperature=0.1,
88 | top_p=0.75,
89 | top_k=40,
90 | num_beams=4,
91 | max_new_tokens=128,
92 | **kwargs,
93 | ):
94 | prompt = prompter.generate_prompt(instruction, input)
95 | inputs = tokenizer(prompt, return_tensors="pt")
96 | input_ids = inputs["input_ids"].to(device)
97 | generation_config = GenerationConfig(
98 | temperature=temperature,
99 | top_p=top_p,
100 | top_k=top_k,
101 | num_beams=num_beams,
102 | **kwargs,
103 | )
104 | with torch.no_grad():
105 | generation_output = model.generate(
106 | input_ids=input_ids,
107 | generation_config=generation_config,
108 | return_dict_in_generate=True,
109 | output_scores=True,
110 | max_new_tokens=max_new_tokens,
111 | )
112 | s = generation_output.sequences[0]
113 | output = tokenizer.decode(s)
114 | return prompter.get_response(output)
115 |
116 | gr.Interface(
117 | fn=evaluate,
118 | inputs=[
119 | gr.components.Textbox(
120 | lines=2,
121 | label="Instruction",
122 | placeholder="Tell me about alpacas.",
123 | ),
124 | gr.components.Textbox(lines=2, label="Input", placeholder="none"),
125 | gr.components.Slider(
126 | minimum=0, maximum=1, value=0.1, label="Temperature"
127 | ),
128 | gr.components.Slider(
129 | minimum=0, maximum=1, value=0.75, label="Top p"
130 | ),
131 | gr.components.Slider(
132 | minimum=0, maximum=100, step=1, value=40, label="Top k"
133 | ),
134 | gr.components.Slider(
135 | minimum=1, maximum=4, step=1, value=4, label="Beams"
136 | ),
137 | gr.components.Slider(
138 | minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
139 | ),
140 | ],
141 | outputs=[
142 |         gr.components.Textbox(
143 | lines=5,
144 | label="Output",
145 | )
146 | ],
147 | title="🦙🌲 Alpaca-LoRA",
148 | description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).", # noqa: E501
149 | ).launch(server_name=server_name, share=share_gradio)
150 | # Old testing code follows.
151 |
152 | """
153 | # testing code for readme
154 | for instruction in [
155 | "Tell me about alpacas.",
156 | "Tell me about the president of Mexico in 2019.",
157 | "Tell me about the king of France in 2019.",
158 | "List all Canadian provinces in alphabetical order.",
159 | "Write a Python program that prints the first 10 Fibonacci numbers.",
160 | "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.", # noqa: E501
161 | "Tell me five words that rhyme with 'shock'.",
162 | "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
163 | "Count up from 1 to 500.",
164 | ]:
165 | print("Instruction:", instruction)
166 | print("Response:", evaluate(instruction))
167 | print()
168 | """
169 |
170 |
171 | if __name__ == "__main__":
172 | fire.Fire(main)
173 |
--------------------------------------------------------------------------------
/src/LoRA/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import fire
5 | import torch
6 | import pandas as pd
7 | from peft import PeftModel
8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
9 | from utils.prompter import Prompter
10 | from tqdm import tqdm
11 |
12 | device = "cuda" if torch.cuda.is_available() else "cpu"
13 |
14 | class InferenceEngine:
15 |
16 | def __init__(self):
17 | self.device = device
18 |
19 | def load_instruction(self, instruct_dir):
20 | input_data = []
21 | with open(instruct_dir, "r") as f:
22 | lines = f.readlines()
23 | for line in lines:
24 | line = line.strip()
25 | d = json.loads(line)
26 | input_data.append(d)
27 | return input_data
28 |
29 | def load_instruction_from_csv(self, instruct_dir, prompt_idx='all'):
30 | input_data = []
31 | df = pd.read_csv(instruct_dir, dtype='str')
32 | if prompt_idx!='all':
33 | df = df[df['prompt_idx'] == str(prompt_idx)]
34 | dict_from_df = df.to_dict(orient='index')
35 | for key,value in dict_from_df.items():
36 | data = {}
37 | data['output'] = value['completion'].strip()
38 | data['instruction'] = value['prompt'].strip()
39 | input_data.append(data)
40 | return input_data, df
41 |
42 | def evaluate(self,
43 | batch,
44 | input=None,
45 | **kwargs,
46 | ):
47 | prompts = [self.prompter.generate_prompt(data["instruction"], input) for data in batch]
48 | inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(device)
49 | generation_config = GenerationConfig(
50 | temperature=self.temperature,
51 | top_p=self.top_p,
52 | top_k=self.top_k,
53 | num_beams=self.num_beams,
54 | **kwargs,
55 | )
56 | with torch.no_grad():
57 | generation_output = self.model.generate(
58 | **inputs,
59 | generation_config=generation_config,
60 | # return_dict_in_generate=True,
61 | # output_scores=True,
62 | max_new_tokens=self.max_new_tokens,
63 | num_return_sequences=self.num_return_sequences,
64 | )
65 | outputs = self.tokenizer.batch_decode(generation_output, skip_special_tokens=True)
66 | return [self.prompter.get_response(output) for output in outputs]
67 |
68 | def infer_from_csv(self, instruct_dir, output_dir, prompt_id):
69 | input_data, df_ori = self.load_instruction_from_csv(instruct_dir, prompt_id)
70 | df_ori.reset_index(drop=True, inplace=True)
71 | col_name = 'model_result'
72 | batched_data = [input_data[i:i+self.batch_size] for i in range(0, len(input_data), self.batch_size)]
73 | model_output_dict = {col_name:[]}
74 | for batch in tqdm(batched_data):
75 | instructions = [data["instruction"] for data in batch]
76 | outputs = self.evaluate(batch)
77 | for i, output in enumerate(outputs):
78 | instruction = instructions[i]
79 | golden_output = batch[i]["output"]
80 | print("###infering###")
81 | print("###instruction###")
82 | print(instruction)
83 | print("###golden output###")
84 | print(golden_output)
85 | print("###model output###")
86 | print(output)
87 | model_output_dict[col_name].append(output)
88 | new_df = pd.DataFrame(model_output_dict)
89 | merged_df = pd.concat([df_ori, new_df], axis=1)
90 | merged_df.to_csv(output_dir + self.output_file_name, index=False)
91 |
92 | def run(self,
93 | load_8bit=False,
94 | base_model="medalpaca/medalpaca-7b",
95 | instruct_dir="../../data/test_prompt.csv",
96 | prompt_id="4",
97 | output_dir="output/",
98 | output_file_name="output.csv",
99 | use_lora=False,
100 | lora_weights="tloen/alpaca-lora-7b",
101 | prompt_template="med_template",
102 | batch_size=4,
103 | temperature=0.1,
104 | top_p=0.75,
105 | top_k=40,
106 | num_beams=4,
107 | max_new_tokens=32,
108 | num_return_sequences=1
109 | ):
110 | self.output_file_name = output_file_name
111 | self.prompter = Prompter(prompt_template)
112 | self.tokenizer = LlamaTokenizer.from_pretrained(base_model, padding_side="left")
113 | self.model = LlamaForCausalLM.from_pretrained(
114 | base_model,
115 | load_in_8bit=load_8bit,
116 | torch_dtype=torch.float16,
117 | device_map="auto",
118 | )
119 | self.batch_size = batch_size
120 | self.temperature = temperature
121 | self.top_p = top_p
122 | self.top_k = top_k
123 | self.num_beams = num_beams
124 | self.max_new_tokens = max_new_tokens
125 | self.num_return_sequences = num_return_sequences
126 |
127 | if use_lora:
128 | print(f"using lora {lora_weights}")
129 | self.model = PeftModel.from_pretrained(
130 | self.model,
131 | lora_weights,
132 | torch_dtype=torch.float16,
133 | )
134 | # unwind broken decapoda-research config
135 | self.model.config.pad_token_id = self.tokenizer.pad_token_id = 0 # unk
136 | self.model.config.bos_token_id = self.tokenizer.bos_token_id
137 | self.model.config.eos_token_id = self.tokenizer.eos_token_id
138 | if not load_8bit:
139 | self.model.half() # seems to fix bugs for some users.
140 |
141 | self.model.eval()
142 |
143 | if torch.__version__ >= "2" and sys.platform != "win32":
144 | self.model = torch.compile(self.model)
145 |
146 | if instruct_dir != "":
147 | filename, file_extension = os.path.splitext(instruct_dir)
148 | file_extension_without_dot = file_extension[1:]
149 | if file_extension_without_dot == 'json':
150 | self.infer_from_json(instruct_dir)
151 | elif file_extension_without_dot == 'csv':
152 | self.infer_from_csv(instruct_dir, output_dir, prompt_id)
153 | else:
154 | raise ValueError
155 | else:
156 | for instruction in [
157 | "我感冒了,怎么治疗",
158 | "一个患有肝衰竭综合征的病人,除了常见的临床表现外,还有哪些特殊的体征?",
159 | "急性阑尾炎和缺血性心脏病的多发群体有何不同?",
160 | "小李最近出现了心动过速的症状,伴有轻度胸痛。体检发现P-R间期延长,伴有T波低平和ST段异常",
161 | ]:
162 | print("Instruction:", instruction)
163 |                 print("Response:", self.evaluate([{"instruction": instruction}]))
164 | print()
165 |
166 | if __name__ == "__main__":
167 | fire.Fire(InferenceEngine().run)
168 |
--------------------------------------------------------------------------------
/src/LoRA/scripts/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | prompt_template="med_template"
3 | prompt_id="1"
4 | num_epochs=10
5 | # LLaMA-CMExam
6 | exp_tag="LLaMA-CMExam"
7 | python finetune.py \
8 | --base_model 'decapoda-research/llama-7b-hf' \
9 | --data_path '../../data/train_prompt.json' \
10 | --valid_data_path '../../data/val_prompt.json' \
11 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
12 | --prompt_template_name $prompt_template \
13 | --micro_batch_size 8 \
14 | --batch_size 128 \
15 | --wandb_run_name $exp_tag \
16 | --prompt_id $prompt_id \
17 | --num_epochs $num_epochs \
18 | --cutoff_len 256 \
19 | --learning_rate 3e-4 \
20 | --lora_r 8 \
21 | --lora_alpha 16
22 | # Alpaca-CMExam
23 | exp_tag="Alpaca-CMExam"
24 | python finetune.py \
25 | --base_model 'decapoda-research/llama-7b-hf' \
26 | --resume_from_checkpoint 'alpaca-lora-7b' \
27 | --data_path '../../data/train_prompt.json' \
28 | --valid_data_path '../../data/val_prompt.json' \
29 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
30 | --prompt_template_name $prompt_template \
31 | --micro_batch_size 8 \
32 | --batch_size 128 \
33 | --wandb_run_name $exp_tag \
34 | --prompt_id $prompt_id \
35 | --num_epochs $num_epochs \
36 | --cutoff_len 256 \
37 | --learning_rate 3e-4 \
38 | --lora_r 16 \
39 | --lora_alpha 16 \
40 | --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]'
41 | # Huatuo-CMExam
42 | exp_tag="Huatuo-CMExam"
43 | python finetune.py \
44 | --base_model 'decapoda-research/llama-7b-hf' \
45 | --resume_from_checkpoint 'lora-alpaca-med' \
46 | --data_path '../../data/train_prompt.json' \
47 | --valid_data_path '../../data/val_prompt.json' \
48 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
49 | --prompt_template_name $prompt_template \
50 | --micro_batch_size 8 \
51 | --batch_size 128 \
52 | --wandb_run_name $exp_tag \
53 | --prompt_id $prompt_id \
54 | --num_epochs $num_epochs \
55 | --cutoff_len 256 \
56 | --learning_rate 3e-4 \
57 | --lora_r 8 \
58 | --lora_alpha 16
59 | # MedAlpaca-CMExam
60 | exp_tag="Medalpaca-CMExam"
61 | python finetune.py \
62 | --base_model 'medalpaca/medalpaca-7b' \
63 | --data_path '../../data/train_prompt.json' \
64 | --valid_data_path '../../data/val_prompt.json' \
65 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
66 | --prompt_template_name $prompt_template \
67 | --micro_batch_size 8 \
68 | --batch_size 128 \
69 | --wandb_run_name $exp_tag \
70 | --prompt_id $prompt_id \
71 | --num_epochs $num_epochs \
72 | --cutoff_len 256 \
73 | --learning_rate 3e-4 \
74 | --lora_r 8 \
75 | --lora_alpha 16
76 | #
77 | prompt_id="4"
78 | num_epochs=1
79 | # LLaMA-CMExam
80 | exp_tag="LLaMA-CMExam"
81 | python finetune.py \
82 | --base_model 'decapoda-research/llama-7b-hf' \
83 | --data_path '../../data/train_prompt.json' \
84 | --valid_data_path '../../data/val_prompt.json' \
85 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
86 | --prompt_template_name $prompt_template \
87 | --micro_batch_size 8 \
88 | --batch_size 128 \
89 | --wandb_run_name $exp_tag \
90 | --prompt_id $prompt_id \
91 | --num_epochs $num_epochs \
92 | --cutoff_len 256 \
93 | --learning_rate 3e-4 \
94 | --lora_r 8 \
95 | --lora_alpha 16
96 | # Alpaca-CMExam
97 | exp_tag="Alpaca-CMExam"
98 | python finetune.py \
99 | --base_model 'decapoda-research/llama-7b-hf' \
100 | --resume_from_checkpoint 'alpaca-lora-7b' \
101 | --data_path '../../data/train_prompt.json' \
102 | --valid_data_path '../../data/val_prompt.json' \
103 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
104 | --prompt_template_name $prompt_template \
105 | --micro_batch_size 8 \
106 | --batch_size 128 \
107 | --wandb_run_name $exp_tag \
108 | --prompt_id $prompt_id \
109 | --num_epochs $num_epochs \
110 | --cutoff_len 256 \
111 | --learning_rate 3e-4 \
112 | --lora_r 16 \
113 | --lora_alpha 16 \
114 | --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]'
115 | # Huatuo-CMExam
116 | exp_tag="Huatuo-CMExam"
117 | python finetune.py \
118 | --base_model 'decapoda-research/llama-7b-hf' \
119 | --resume_from_checkpoint 'lora-alpaca-med' \
120 | --data_path '../../data/train_prompt.json' \
121 | --valid_data_path '../../data/val_prompt.json' \
122 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
123 | --prompt_template_name $prompt_template \
124 | --micro_batch_size 8 \
125 | --batch_size 128 \
126 | --wandb_run_name $exp_tag \
127 | --prompt_id $prompt_id \
128 | --num_epochs $num_epochs \
129 | --cutoff_len 256 \
130 | --learning_rate 3e-4 \
131 | --lora_r 8 \
132 | --lora_alpha 16
133 | # MedAlpaca-CMExam
134 | exp_tag="Medalpaca-CMExam"
135 | python finetune.py \
136 | --base_model 'medalpaca/medalpaca-7b' \
137 | --data_path '../../data/train_prompt.json' \
138 | --valid_data_path '../../data/val_prompt.json' \
139 | --output_dir './saved/lora-cmexam-'${exp_tag}'-'${prompt_id} \
140 | --prompt_template_name $prompt_template \
141 | --micro_batch_size 8 \
142 | --batch_size 128 \
143 | --wandb_run_name $exp_tag \
144 | --prompt_id $prompt_id \
145 | --num_epochs $num_epochs \
146 | --cutoff_len 256 \
147 | --learning_rate 3e-4 \
148 | --lora_r 8 \
149 | --lora_alpha 16
--------------------------------------------------------------------------------
/src/LoRA/scripts/infer_ori.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # medalpaca prompt 1
3 | CUDA_VISIBLE_DEVICES=0 python infer.py \
4 | --base_model 'medalpaca/medalpaca-7b' \
5 | --use_lora False \
6 | --instruct_dir '../../data/test_prompt.csv' \
7 | --prompt_template 'med_template' \
8 | --output_file_name 'medalpaca_1.csv' \
9 | --prompt_id '1' \
10 | --batch_size 4 \
11 | --num_beams 1 \
12 | --max_new_tokens 64
13 | # medalpaca prompt 4
14 | CUDA_VISIBLE_DEVICES=0 python infer.py \
15 | --base_model 'medalpaca/medalpaca-7b' \
16 | --use_lora False \
17 | --instruct_dir '../../data/test_prompt.csv' \
18 | --prompt_template 'med_template' \
19 | --output_file_name 'medalpaca_4.csv' \
20 | --prompt_id '4' \
21 | --batch_size 2 \
22 | --num_beams 4 \
23 | --max_new_tokens 256
--------------------------------------------------------------------------------
/src/LoRA/scripts/infer_sft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # LLaMA-CMExam prompt 1
3 | model_name='LLaMA-CMExam'
4 | prompt_id='1'
5 | CUDA_VISIBLE_DEVICES=0 python infer.py \
6 | --base_model 'decapoda-research/llama-7b-hf' \
7 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
8 | --use_lora True \
9 | --instruct_dir '../../data/test_prompt.csv' \
10 | --prompt_template 'med_template' \
11 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
12 | --prompt_id ${prompt_id} \
13 | --batch_size 4 \
14 | --num_beams 1 \
15 | --max_new_tokens 32
16 | # LLaMA-CMExam prompt 4
17 | model_name='LLaMA-CMExam'
18 | prompt_id='4'
19 | CUDA_VISIBLE_DEVICES=0 python infer.py \
20 | --base_model 'decapoda-research/llama-7b-hf' \
21 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
22 | --use_lora True \
23 | --instruct_dir '../../data/test_prompt.csv' \
24 | --prompt_template 'med_template' \
25 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
26 | --prompt_id ${prompt_id} \
27 | --batch_size 4 \
28 | --num_beams 4 \
29 | --max_new_tokens 256
30 | # Alpaca-CMExam prompt 1
31 | model_name='Alpaca-CMExam'
32 | prompt_id='1'
33 | CUDA_VISIBLE_DEVICES=0 python infer.py \
34 | --base_model 'decapoda-research/llama-7b-hf' \
35 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
36 | --use_lora True \
37 | --instruct_dir '../../data/test_prompt.csv' \
38 | --prompt_template 'med_template' \
39 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
40 | --prompt_id ${prompt_id} \
41 | --batch_size 4 \
42 | --num_beams 1 \
43 | --max_new_tokens 32
44 | # Alpaca-CMExam prompt 4
45 | model_name='Alpaca-CMExam'
46 | prompt_id='4'
47 | CUDA_VISIBLE_DEVICES=0 python infer.py \
48 | --base_model 'decapoda-research/llama-7b-hf' \
49 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
50 | --use_lora True \
51 | --instruct_dir '../../data/test_prompt.csv' \
52 | --prompt_template 'med_template' \
53 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
54 | --prompt_id ${prompt_id} \
55 | --batch_size 4 \
56 | --num_beams 4 \
57 | --max_new_tokens 256
58 | # Huatuo-CMExam prompt 1
59 | model_name='Huatuo-CMExam'
60 | prompt_id='1'
61 | CUDA_VISIBLE_DEVICES=0 python infer.py \
62 | --base_model 'decapoda-research/llama-7b-hf' \
63 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
64 | --use_lora True \
65 | --instruct_dir '../../data/test_prompt.csv' \
66 | --prompt_template 'med_template' \
67 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
68 | --prompt_id ${prompt_id} \
69 | --batch_size 4 \
70 | --num_beams 1 \
71 | --max_new_tokens 32
72 | # Huatuo-CMExam prompt 4
73 | model_name='Huatuo-CMExam'
74 | prompt_id='4'
75 | CUDA_VISIBLE_DEVICES=0 python infer.py \
76 | --base_model 'decapoda-research/llama-7b-hf' \
77 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
78 | --use_lora True \
79 | --instruct_dir '../../data/test_prompt.csv' \
80 | --prompt_template 'med_template' \
81 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
82 | --prompt_id ${prompt_id} \
83 | --batch_size 4 \
84 | --num_beams 4 \
85 | --max_new_tokens 256
86 | # Medalpaca-CMExam prompt 1
87 | model_name='Medalpaca-CMExam'
88 | prompt_id='1'
89 | CUDA_VISIBLE_DEVICES=0 python infer.py \
90 | --base_model 'medalpaca/medalpaca-7b' \
91 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
92 | --use_lora True \
93 | --instruct_dir '../../data/test_prompt.csv' \
94 | --prompt_template 'med_template' \
95 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
96 | --prompt_id ${prompt_id} \
97 | --batch_size 4 \
98 | --num_beams 1 \
99 | --max_new_tokens 32
100 | # Medalpaca-CMExam prompt 4
101 | model_name='Medalpaca-CMExam'
102 | prompt_id='4'
103 | CUDA_VISIBLE_DEVICES=0 python infer.py \
104 | --base_model 'medalpaca/medalpaca-7b' \
105 | --lora_weights './saved/lora-cmexam-'${model_name}'-'${prompt_id}'/' \
106 | --use_lora True \
107 | --instruct_dir '../../data/test_prompt.csv' \
108 | --prompt_template 'med_template' \
109 | --output_file_name ${model_name}'-'${prompt_id}'.csv' \
110 | --prompt_id ${prompt_id} \
111 | --batch_size 4 \
112 | --num_beams 4 \
113 | --max_new_tokens 256
--------------------------------------------------------------------------------
/src/LoRA/templates/README.md:
--------------------------------------------------------------------------------
1 | # Prompt templates
2 |
3 | This directory contains template styles for the prompts used to finetune LoRA models.
4 |
5 | ## Format
6 |
7 | A template is described via a JSON file with the following keys:
8 |
9 | - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
10 | - `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
11 | - `description`: A short description of the template, with possible use cases.
12 | - `response_split`: The text to use as separator when cutting real response from the model output.
13 |
14 | No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest.
15 |
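For illustration, here is a minimal sketch of how such a template could be applied. It approximates what the `Prompter` helper in `utils/prompter.py` does; the `build_prompt`/`get_response` names and the template path are just for this example:

```python
import json

# Illustrative only: load a template file and build prompts from it.
with open("templates/alpaca.json", encoding="utf-8") as f:
    template = json.load(f)

def build_prompt(instruction, input=None, label=None):
    # Pick the variant with or without an input field.
    if input:
        prompt = template["prompt_input"].format(instruction=instruction, input=input)
    else:
        prompt = template["prompt_no_input"].format(instruction=instruction)
    if label:
        # The response is simply concatenated after the prompt.
        prompt += label
    return prompt

def get_response(model_output):
    # Everything after the response marker is treated as the model's answer.
    return model_output.split(template["response_split"])[1].strip()
```
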
16 | ## Example template
17 |
18 | The default template, used unless otherwise specified, is `alpaca.json`
19 |
20 | ```json
21 | {
22 | "description": "Template used by Alpaca-LoRA.",
23 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
24 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
25 | "response_split": "### Response:"
26 | }
27 |
28 | ```
29 |
30 | ## Current templates
31 |
32 | ### alpaca
33 |
34 | Default template used for generic LoRA fine tunes so far.
35 |
36 | ### alpaca_legacy
37 |
38 | Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
39 |
40 | ### alpaca_short
41 |
42 | A trimmed-down alpaca template which seems to perform just as well and saves some tokens. Models created with the default template seem to be queryable by the short template as well. More experiments are welcome.
43 |
44 | ### vigogne
45 |
46 | The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine-tuning.
47 |
--------------------------------------------------------------------------------
/src/LoRA/templates/med_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": "Template used by Med Instruction Tuning",
3 | "prompt_input": "{instruction}\n### 回答:\n",
4 | "prompt_no_input": "{instruction}\n### 回答:\n",
5 | "response_split": "### 回答:"
6 | }
--------------------------------------------------------------------------------
/src/LoRA/utils/README.md:
--------------------------------------------------------------------------------
1 | # Directory for helper modules
2 |
3 | ## prompter.py
4 |
5 | Prompter class, a template manager.
6 |
7 | `from utils.prompter import Prompter`
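8 | 
9 | A minimal usage sketch (the template name and the instruction string below are placeholder examples; the relative `./templates` path assumes the script is launched from `src/LoRA/`):
10 | 
11 | ```python
12 | from utils.prompter import Prompter
13 | 
14 | # Loads ./templates/med_template.json and prints its description when verbose=True.
15 | prompter = Prompter("med_template", verbose=True)
16 | 
17 | # Build the full prompt for one example (no separate input field here).
18 | prompt = prompter.generate_prompt(instruction="问题: ...\n选项: ...")
19 | 
20 | # Given raw model output, keep only the text after the response separator.
21 | # answer = prompter.get_response(model_output)
22 | ```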
--------------------------------------------------------------------------------
/src/LoRA/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/src/LoRA/utils/__init__.py
--------------------------------------------------------------------------------
/src/LoRA/utils/data_format_transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/5/17 01:01
3 | # @Author : Peilin Zhou
4 | # @FileName: data_format_transform.py
5 | # @Software: PyCharm
6 | # @E-mail : zhoupl@pku.edu.cn
7 | import json
8 | import os
9 | import argparse
10 |
11 | def filter_and_convert(input_file, target_id, sample=None):
12 |     filtered_data = []
13 |     target_id = str(target_id) if target_id is not None else None  # keep None so the '_all' fallback below still works
14 | 
15 |     output_file_name = os.path.splitext(input_file)[0]
16 |     if target_id is not None:
17 |         output_file_name += '_' + target_id
18 |     else:
19 |         output_file_name += '_all'
20 |
21 | with open(input_file, 'r', encoding='utf-8') as f:
22 | for line in f:
23 | data = json.loads(line)
24 | if target_id is None or target_id=='all' or data['id'] == target_id:
25 | filtered_data.append({
26 | 'instruction': data['prompt'],
27 | 'input': '',
28 | 'output': data['completion']
29 | })
30 |
31 | output_file = output_file_name + '.json'
32 |
33 | with open(output_file, 'w', encoding='utf-8') as f:
34 | if sample:
35 | for data in filtered_data[:sample]:
36 | f.write(json.dumps(data, ensure_ascii=False) + '\n')
37 | else:
38 | for data in filtered_data:
39 | f.write(json.dumps(data, ensure_ascii=False) + '\n')
40 |
41 |     print(f"Filtered file is saved to: {output_file}")
42 | return output_file
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser(description='Filter and convert JSON file.')
46 |     parser.add_argument('input_file', type=str, nargs='?', help='path to the input JSON file', default='data/train_prompt.json')
47 | parser.add_argument('target_id', type=str, nargs='?', default=None, help='target ID for filtering (optional)')
48 | args = parser.parse_args()
49 |
50 | input_file_path = args.input_file
51 | target_id = args.target_id
52 |
53 | filter_and_convert(input_file_path, target_id)
--------------------------------------------------------------------------------
/src/LoRA/utils/prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | A dedicated helper to manage templates and prompt building.
3 | """
4 |
5 | import json
6 | import os.path as osp
7 | from typing import Union
8 |
9 |
10 | class Prompter(object):
11 | __slots__ = ("template", "_verbose")
12 |
13 | def __init__(self, template_name: str = "", verbose: bool = False):
14 | self._verbose = verbose
15 | if not template_name:
16 | # Enforce the default here, so the constructor can be called with '' and will not break.
17 | template_name = "alpaca"
18 | file_name = osp.join("./templates", f"{template_name}.json")
19 | if not osp.exists(file_name):
20 | raise ValueError(f"Can't read {file_name}")
21 | with open(file_name) as fp:
22 | self.template = json.load(fp)
23 | if self._verbose:
24 | print(
25 | f"Using prompt template {template_name}: {self.template['description']}"
26 | )
27 |
28 | def generate_prompt(
29 | self,
30 | instruction: str,
31 | input: Union[None, str] = None,
32 | label: Union[None, str] = None,
33 | ) -> str:
34 | # returns the full prompt from instruction and optional input
35 | # if a label (=response, =output) is provided, it's also appended.
36 | if input:
37 | res = self.template["prompt_input"].format(
38 | instruction=instruction, input=input
39 | )
40 | else:
41 | res = self.template["prompt_no_input"].format(
42 | instruction=instruction
43 | )
44 | if label:
45 | res = f"{res}{label}"
46 | if self._verbose:
47 | print(res)
48 | return res
49 |
50 | def get_response(self, output: str) -> str:
51 | return output.split(self.template["response_split"])[1].strip()
52 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate/bleu.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
3 |
4 | Python implementation of BLEU and smooth-BLEU.
5 |
6 | This module provides a Python implementation of BLEU and smooth-BLEU.
7 | Smooth BLEU is computed following the method outlined in the paper:
8 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic
9 | evaluation metrics for machine translation. COLING 2004.
10 | """
11 |
12 | import collections
13 | import math
14 |
15 |
16 | def _get_ngrams(segment, max_order):
17 |   """Extracts all n-grams up to a given maximum order from an input segment.
18 | 
19 |   Args:
20 |     segment: text segment from which n-grams will be extracted.
21 |     max_order: maximum length in tokens of the n-grams returned by this
22 |         method.
23 | 
24 |   Returns:
25 |     The Counter containing all n-grams up to max_order in segment
26 | with a count of how many times each n-gram occurred.
27 | """
28 | ngram_counts = collections.Counter()
29 | for order in range(1, max_order + 1):
30 | for i in range(0, len(segment) - order + 1):
31 | ngram = tuple(segment[i:i+order])
32 | ngram_counts[ngram] += 1
33 | return ngram_counts
34 |
35 |
36 | def compute_bleu(reference_corpus, translation_corpus, max_order=4,
37 | smooth=False):
38 | """Computes BLEU score of translated segments against one or more references.
39 |
40 | Args:
41 | reference_corpus: list of lists of references for each translation. Each
42 | reference should be tokenized into a list of tokens.
43 | translation_corpus: list of translations to score. Each translation
44 | should be tokenized into a list of tokens.
45 | max_order: Maximum n-gram order to use when computing BLEU score.
46 | smooth: Whether or not to apply Lin et al. 2004 smoothing.
47 |
48 | Returns:
49 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
50 | precisions and brevity penalty.
51 | """
52 | matches_by_order = [0] * max_order
53 | possible_matches_by_order = [0] * max_order
54 | reference_length = 0
55 | translation_length = 0
56 | for (references, translation) in zip(reference_corpus,
57 | translation_corpus):
58 | reference_length += min(len(r) for r in references)
59 | translation_length += len(translation)
60 |
61 | merged_ref_ngram_counts = collections.Counter()
62 | for reference in references:
63 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
64 | translation_ngram_counts = _get_ngrams(translation, max_order)
65 | overlap = translation_ngram_counts & merged_ref_ngram_counts
66 | for ngram in overlap:
67 | matches_by_order[len(ngram)-1] += overlap[ngram]
68 | for order in range(1, max_order+1):
69 | possible_matches = len(translation) - order + 1
70 | if possible_matches > 0:
71 | possible_matches_by_order[order-1] += possible_matches
72 |
73 | precisions = [0] * max_order
74 | for i in range(0, max_order):
75 | if smooth:
76 | precisions[i] = ((matches_by_order[i] + 1.) /
77 | (possible_matches_by_order[i] + 1.))
78 | else:
79 | if possible_matches_by_order[i] > 0:
80 | precisions[i] = (float(matches_by_order[i]) /
81 | possible_matches_by_order[i])
82 | else:
83 | precisions[i] = 0.0
84 |
85 | if min(precisions) > 0:
86 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
87 | geo_mean = math.exp(p_log_sum)
88 | else:
89 | geo_mean = 0
90 |
91 | ratio = float(translation_length) / reference_length
92 |
93 | if ratio > 1.0:
94 | bp = 1.
95 | else:
96 | bp = math.exp(1 - 1. / ratio)
97 |
98 | bleu = geo_mean * bp
99 |
100 | return (bleu, precisions, bp, ratio, translation_length, reference_length)
101 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate/metrics4rec.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import math
4 | import numpy as np
5 | import heapq
6 |
7 |
8 | def evaluate_old(predict, groundtruth, topk=10):
9 | """[Deprecated] Compute metrics for predicted recommendations.
10 | Args:
11 |         predict: a dict with key = <user id> and value = <ranked list of predicted item ids>.
12 |         groundtruth: a dict with key = <user id> and value = <list of relevant item ids>.
13 | Returns:
14 | Dict of metrics.
15 | """
16 | invalid_users = []
17 |
18 | # Compute metrics
19 | precisions, recalls, ndcgs, hits = [], [], [], []
20 | for uid in groundtruth:
21 | if uid not in predict or len(predict[uid]) < topk:
22 | invalid_users.append(uid)
23 | continue
24 | pred_list, rel_set = predict[uid][:topk], groundtruth[uid]
25 | if len(pred_list) == 0:
26 | continue
27 |
28 | dcg = 0.0
29 | hit_num = 0.0
30 | for i in range(len(pred_list)):
31 | if pred_list[i] in rel_set:
32 | dcg += 1.0 / (math.log(i + 2) / math.log(2))
33 | hit_num += 1
34 | # idcg
35 | idcg = 0.0
36 | for i in range(min(len(rel_set), len(pred_list))):
37 | idcg += 1.0 / (math.log(i + 2) / math.log(2))
38 | ndcg = dcg / idcg
39 | recall = hit_num / len(rel_set)
40 | precision = hit_num / len(pred_list)
41 | hit = 1.0 if hit_num > 0.0 else 0.0
42 |
43 | ndcgs.append(ndcg)
44 | recalls.append(recall)
45 | precisions.append(precision)
46 | hits.append(hit)
47 |
48 | avg_precision = np.mean(precisions)
49 | avg_recall = np.mean(recalls)
50 | avg_ndcg = np.mean(ndcgs)
51 | avg_hit = np.mean(hits)
52 | msg = "NDCG={:.4f} | Recall={:.4f} | HR={:.4f} | Precision={:.4f} | Invalid users={}".format(
53 | avg_ndcg, avg_recall, avg_hit, avg_precision, len(invalid_users)
54 | )
55 | print(msg)
56 | return msg
57 |
58 |
59 | def recall_at_k(r, k, all_pos_num):
60 | r = np.asarray(r)[:k]
61 | return np.sum(r) / all_pos_num
62 |
63 |
64 | def hit_at_k(r, k):
65 | r = np.asarray(r)[:k]
66 | if np.sum(r) > 0:
67 | return 1.0
68 | else:
69 | return 0.0
70 |
71 |
72 | def mean_reciprocal_rank(rs):
73 | """Score is reciprocal of the rank of the first relevant item
74 | First element is 'rank 1'. Relevance is binary (nonzero is relevant).
75 | Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank
76 | >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
77 | >>> mean_reciprocal_rank(rs)
78 | 0.61111111111111105
79 | >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]])
80 | >>> mean_reciprocal_rank(rs)
81 | 0.5
82 | >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]]
83 | >>> mean_reciprocal_rank(rs)
84 | 0.75
85 | Args:
86 | rs: Iterator of relevance scores (list or numpy) in rank order
87 | (first element is the first item)
88 | Returns:
89 | Mean reciprocal rank
90 | """
91 | rs = (np.asarray(r).nonzero()[0] for r in rs)
92 | return np.mean([1.0 / (r[0] + 1) if r.size else 0.0 for r in rs])
93 |
94 |
95 | def r_precision(r):
96 | """Score is precision after all relevant documents have been retrieved
97 | Relevance is binary (nonzero is relevant).
98 | >>> r = [0, 0, 1]
99 | >>> r_precision(r)
100 | 0.33333333333333331
101 | >>> r = [0, 1, 0]
102 | >>> r_precision(r)
103 | 0.5
104 | >>> r = [1, 0, 0]
105 | >>> r_precision(r)
106 | 1.0
107 | Args:
108 | r: Relevance scores (list or numpy) in rank order
109 | (first element is the first item)
110 | Returns:
111 | R Precision
112 | """
113 | r = np.asarray(r) != 0
114 | z = r.nonzero()[0]
115 | if not z.size:
116 | return 0.0
117 | return np.mean(r[: z[-1] + 1])
118 |
119 |
120 | def precision_at_k(r, k):
121 | """Score is precision @ k
122 | Relevance is binary (nonzero is relevant).
123 | >>> r = [0, 0, 1]
124 | >>> precision_at_k(r, 1)
125 | 0.0
126 | >>> precision_at_k(r, 2)
127 | 0.0
128 | >>> precision_at_k(r, 3)
129 | 0.33333333333333331
130 | >>> precision_at_k(r, 4)
131 | Traceback (most recent call last):
132 |     File "<stdin>", line 1, in ?
133 | ValueError: Relevance score length < k
134 | Args:
135 | r: Relevance scores (list or numpy) in rank order
136 | (first element is the first item)
137 | Returns:
138 | Precision @ k
139 | Raises:
140 | ValueError: len(r) must be >= k
141 | """
142 | assert k >= 1
143 | r = np.asarray(r)[:k] != 0
144 | if r.size != k:
145 | raise ValueError("Relevance score length < k")
146 | return np.mean(r)
147 |
148 |
149 | def average_precision(r):
150 | """Score is average precision (area under PR curve)
151 | Relevance is binary (nonzero is relevant).
152 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]
153 | >>> delta_r = 1. / sum(r)
154 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y])
155 | 0.7833333333333333
156 | >>> average_precision(r)
157 | 0.78333333333333333
158 | Args:
159 | r: Relevance scores (list or numpy) in rank order
160 | (first element is the first item)
161 | Returns:
162 | Average precision
163 | """
164 | r = np.asarray(r) != 0
165 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]
166 | if not out:
167 | return 0.0
168 | return np.mean(out)
169 |
170 |
171 | def mean_average_precision(rs):
172 | """Score is mean average precision
173 | Relevance is binary (nonzero is relevant).
174 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]]
175 | >>> mean_average_precision(rs)
176 | 0.78333333333333333
177 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]]
178 | >>> mean_average_precision(rs)
179 | 0.39166666666666666
180 | Args:
181 | rs: Iterator of relevance scores (list or numpy) in rank order
182 | (first element is the first item)
183 | Returns:
184 | Mean average precision
185 | """
186 | return np.mean([average_precision(r) for r in rs])
187 |
188 |
189 | def dcg_at_k(r, k, method=1):
190 | """Score is discounted cumulative gain (dcg)
191 | Relevance is positive real values. Can use binary
192 | as the previous methods.
193 | Example from
194 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
195 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
196 | >>> dcg_at_k(r, 1)
197 | 3.0
198 | >>> dcg_at_k(r, 1, method=1)
199 | 3.0
200 | >>> dcg_at_k(r, 2)
201 | 5.0
202 | >>> dcg_at_k(r, 2, method=1)
203 | 4.2618595071429155
204 | >>> dcg_at_k(r, 10)
205 | 9.6051177391888114
206 | >>> dcg_at_k(r, 11)
207 | 9.6051177391888114
208 | Args:
209 | r: Relevance scores (list or numpy) in rank order
210 | (first element is the first item)
211 | k: Number of results to consider
212 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]
213 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
214 | Returns:
215 | Discounted cumulative gain
216 | """
217 | r = np.asfarray(r)[:k]
218 | if r.size:
219 | if method == 0:
220 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
221 | elif method == 1:
222 | return np.sum(r / np.log2(np.arange(2, r.size + 2)))
223 | else:
224 | raise ValueError("method must be 0 or 1.")
225 | return 0.0
226 |
227 |
228 | def ndcg_at_k(r, k, method=1):
229 | """Score is normalized discounted cumulative gain (ndcg)
230 | Relevance is positive real values. Can use binary
231 | as the previous methods.
232 | Example from
233 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
234 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
235 | >>> ndcg_at_k(r, 1)
236 | 1.0
237 | >>> r = [2, 1, 2, 0]
238 | >>> ndcg_at_k(r, 4)
239 | 0.9203032077642922
240 | >>> ndcg_at_k(r, 4, method=1)
241 | 0.96519546960144276
242 | >>> ndcg_at_k([0], 1)
243 | 0.0
244 | >>> ndcg_at_k([1], 2)
245 | 1.0
246 | Args:
247 | r: Relevance scores (list or numpy) in rank order
248 | (first element is the first item)
249 | k: Number of results to consider
250 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]
251 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
252 | Returns:
253 | Normalized discounted cumulative gain
254 | """
255 | dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
256 | if not dcg_max:
257 | return 0.0
258 | return dcg_at_k(r, k, method) / dcg_max
259 |
260 |
261 | def evaluate_once(topk_preds, groundtruth):
262 | """Evaluate one user performance.
263 | Args:
264 |         topk_preds: list of <item id>. length of the list is topK.
265 |         groundtruth: list of <item id>.
266 | Returns:
267 | dict of metrics.
268 | """
269 | gt_set = set(groundtruth)
270 | topk = len(topk_preds)
271 | rel = []
272 | for iid in topk_preds:
273 | if iid in gt_set:
274 | rel.append(1)
275 | else:
276 | rel.append(0)
277 | return {
278 | "precision@k": precision_at_k(rel, topk),
279 | "recall@k": recall_at_k(rel, topk, len(gt_set)),
280 | "ndcg@k": ndcg_at_k(rel, topk, 1),
281 | "hit@k": hit_at_k(rel, topk),
282 | "ap": average_precision(rel),
283 | "rel": rel,
284 | }
285 |
286 |
287 | def evaluate_all(user_item_scores, groudtruth, topk=10):
288 | """Evaluate all user-items performance.
289 | Args:
290 |         user_item_scores: dict with key = <user id>, value = <dict of item id: score>.
291 |                 Make sure larger score means better recommendation.
292 |         groudtruth: dict with key = <user id>, value = list of <item id>.
293 | topk: int
294 | Returns:
295 | """
296 | avg_prec, avg_recall, avg_ndcg, avg_hit = 0.0, 0.0, 0.0, 0.0
297 | rs = []
298 | cnt = 0
299 | for uid in user_item_scores:
300 | # [Important] Use shuffle to break ties!!!
301 | ui_scores = list(user_item_scores[uid].items())
302 | np.random.shuffle(ui_scores) # break ties
303 | # topk_preds = heapq.nlargest(topk, user_item_scores[uid], key=user_item_scores[uid].get) # list of k
304 | topk_preds = heapq.nlargest(topk, ui_scores, key=lambda x: x[1]) # list of k tuples
305 | topk_preds = [x[0] for x in topk_preds] # list of k
306 | # print(topk_preds, groudtruth[uid])
307 | result = evaluate_once(topk_preds, groudtruth[uid])
308 | avg_prec += result["precision@k"]
309 | avg_recall += result["recall@k"]
310 | avg_ndcg += result["ndcg@k"]
311 | avg_hit += result["hit@k"]
312 | rs.append(result["rel"])
313 | cnt += 1
314 |
315 | # [CAVEAT] Following code calculates metrics for each gt item.
316 | # for iid in groudtruth[uid]:
317 | # result = evaluate_once(topk_preds, [iid])
318 | # avg_prec += result["precision@k"]
319 | # avg_recall += result["recall@k"]
320 | # avg_ndcg += result["ndcg@k"]
321 | # avg_hit += result["hit@k"]
322 | # rs.append(result["rel"])
323 | # cnt += 1
324 |
325 | avg_prec = avg_prec / cnt
326 | avg_recall = avg_recall / cnt
327 | avg_ndcg = avg_ndcg / cnt
328 | avg_hit = avg_hit / cnt
329 | map_ = mean_average_precision(rs)
330 | mrr = mean_reciprocal_rank(rs)
331 | msg = "\nNDCG@{}\tRec@{}\tHits@{}\tPrec@{}\tMAP@{}\tMRR@{}".format(topk, topk, topk, topk, topk, topk)
332 | msg += "\n{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}".format(avg_ndcg, avg_recall, avg_hit, avg_prec, map_, mrr)
333 | # msg = "NDCG@{}\tRec@{}\tMAP@{}".format(topk, topk, topk)
334 | # msg += "\n{:.4f}\t{:.4f}\t{:.4f}".format(avg_ndcg, avg_recall, map)
335 | print(msg)
336 | res = {
337 | 'ndcg': avg_ndcg,
338 | 'map': map_,
339 | 'recall': avg_recall,
340 | 'precision': avg_prec,
341 | 'mrr': mrr,
342 | 'hit': avg_hit,
343 | }
344 | return msg, res
345 |
346 |
347 | def main():
348 | ui_scores = {
349 | 1: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7},
350 | # 2: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7},
351 | # 3: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7},
352 | # 4: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7},
353 | # 5: {11: 3, 12: 4, 13: 5, 14: 6, 15: 7},
354 | }
355 | gt = {
356 | 1: [11, 15],
357 | # 2: [12, 13],
358 | # 3: [11, 14],
359 | # 4: [12, 15],
360 | # 5: [11],
361 | }
362 | evaluate_all(ui_scores, gt, 5)
363 |
364 | # pred = {}
365 | # for uid in ui_scores:
366 | # pred[uid] = heapq.nlargest(3, ui_scores[uid], key=ui_scores[uid].get)
367 | # evaluate_old(pred, gt, 3)
368 |
369 |
370 | if __name__ == "__main__":
371 | main()
372 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate/rouge.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from https://github.com/tensorflow/nmt/blob/master/nmt/scripts/rouge.py
3 |
4 | ROUGE metric implementation.
5 |
6 | Copy from tf_seq2seq/seq2seq/metrics/rouge.py.
7 | This is a modified and slightly extended version of
8 | https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py.
9 | """
10 |
11 | from __future__ import absolute_import
12 | from __future__ import division
13 | from __future__ import print_function
14 | from __future__ import unicode_literals
15 |
16 | import itertools
17 | import numpy as np
18 |
19 | #pylint: disable=C0103
20 |
21 |
22 | def _get_ngrams(n, text):
23 |   """Calculates n-grams.
24 |
25 | Args:
26 | n: which n-grams to calculate
27 | text: An array of tokens
28 |
29 | Returns:
30 | A set of n-grams
31 | """
32 | ngram_set = set()
33 | text_length = len(text)
34 | max_index_ngram_start = text_length - n
35 | for i in range(max_index_ngram_start + 1):
36 | ngram_set.add(tuple(text[i:i + n]))
37 | return ngram_set
38 |
39 |
40 | def _split_into_words(sentences):
41 | """Splits multiple sentences into words and flattens the result"""
42 | return list(itertools.chain(*[_.split(" ") for _ in sentences]))
43 |
44 |
45 | def _get_word_ngrams(n, sentences):
46 | """Calculates word n-grams for multiple sentences.
47 | """
48 | assert len(sentences) > 0
49 | assert n > 0
50 |
51 | words = _split_into_words(sentences)
52 | return _get_ngrams(n, words)
53 |
54 |
55 | def _len_lcs(x, y):
56 | """
57 | Returns the length of the Longest Common Subsequence between sequences x
58 | and y.
59 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
60 |
61 | Args:
62 | x: sequence of words
63 | y: sequence of words
64 |
65 | Returns
66 | integer: Length of LCS between x and y
67 | """
68 | table = _lcs(x, y)
69 | n, m = len(x), len(y)
70 | return table[n, m]
71 |
72 |
73 | def _lcs(x, y):
74 | """
75 | Computes the length of the longest common subsequence (lcs) between two
76 | strings. The implementation below uses a DP programming algorithm and runs
77 | in O(nm) time where n = len(x) and m = len(y).
78 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
79 |
80 | Args:
81 | x: collection of words
82 | y: collection of words
83 |
84 | Returns:
85 | Table of dictionary of coord and len lcs
86 | """
87 | n, m = len(x), len(y)
88 | table = dict()
89 | for i in range(n + 1):
90 | for j in range(m + 1):
91 | if i == 0 or j == 0:
92 | table[i, j] = 0
93 | elif x[i - 1] == y[j - 1]:
94 | table[i, j] = table[i - 1, j - 1] + 1
95 | else:
96 | table[i, j] = max(table[i - 1, j], table[i, j - 1])
97 | return table
98 |
99 |
100 | def _recon_lcs(x, y):
101 | """
102 | Returns the Longest Subsequence between x and y.
103 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
104 |
105 | Args:
106 | x: sequence of words
107 | y: sequence of words
108 |
109 | Returns:
110 | sequence: LCS of x and y
111 | """
112 | i, j = len(x), len(y)
113 | table = _lcs(x, y)
114 |
115 | def _recon(i, j):
116 | """private recon calculation"""
117 | if i == 0 or j == 0:
118 | return []
119 | elif x[i - 1] == y[j - 1]:
120 | return _recon(i - 1, j - 1) + [(x[i - 1], i)]
121 | elif table[i - 1, j] > table[i, j - 1]:
122 | return _recon(i - 1, j)
123 | else:
124 | return _recon(i, j - 1)
125 |
126 | recon_tuple = tuple(map(lambda x: x[0], _recon(i, j)))
127 | return recon_tuple
128 |
129 |
130 | def rouge_n(evaluated_sentences, reference_sentences, n=2):
131 | """
132 | Computes ROUGE-N of two text collections of sentences.
133 |   Source: http://research.microsoft.com/en-us/um/people/cyl/download/
134 | papers/rouge-working-note-v1.3.1.pdf
135 |
136 | Args:
137 | evaluated_sentences: The sentences that have been picked by the summarizer
138 |     reference_sentences: The sentences from the reference set
139 | n: Size of ngram. Defaults to 2.
140 |
141 | Returns:
142 | A tuple (f1, precision, recall) for ROUGE-N
143 |
144 | Raises:
145 | ValueError: raises exception if a param has len <= 0
146 | """
147 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
148 | raise ValueError("Collections must contain at least 1 sentence.")
149 |
150 | evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
151 | reference_ngrams = _get_word_ngrams(n, reference_sentences)
152 | reference_count = len(reference_ngrams)
153 | evaluated_count = len(evaluated_ngrams)
154 |
155 | # Gets the overlapping ngrams between evaluated and reference
156 | overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
157 | overlapping_count = len(overlapping_ngrams)
158 |
159 | # Handle edge case. This isn't mathematically correct, but it's good enough
160 | if evaluated_count == 0:
161 | precision = 0.0
162 | else:
163 | precision = overlapping_count / evaluated_count
164 |
165 | if reference_count == 0:
166 | recall = 0.0
167 | else:
168 | recall = overlapping_count / reference_count
169 |
170 | f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
171 |
172 | # return overlapping_count / reference_count
173 | return f1_score, precision, recall
174 |
175 |
176 | def _f_p_r_lcs(llcs, m, n):
177 | """
178 | Computes the LCS-based F-measure score
179 | Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
180 | rouge-working-note-v1.3.1.pdf
181 |
182 | Args:
183 | llcs: Length of LCS
184 | m: number of words in reference summary
185 | n: number of words in candidate summary
186 |
187 | Returns:
188 | Float. LCS-based F-measure score
189 | """
190 | r_lcs = llcs / m
191 | p_lcs = llcs / n
192 | beta = p_lcs / (r_lcs + 1e-12)
193 | num = (1 + (beta**2)) * r_lcs * p_lcs
194 | denom = r_lcs + ((beta**2) * p_lcs)
195 | f_lcs = num / (denom + 1e-12)
196 | return f_lcs, p_lcs, r_lcs
197 |
198 |
199 | def rouge_l_sentence_level(evaluated_sentences, reference_sentences):
200 | """
201 | Computes ROUGE-L (sentence level) of two text collections of sentences.
202 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/
203 | rouge-working-note-v1.3.1.pdf
204 |
205 | Calculated according to:
206 | R_lcs = LCS(X,Y)/m
207 | P_lcs = LCS(X,Y)/n
208 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
209 |
210 | where:
211 | X = reference summary
212 | Y = Candidate summary
213 | m = length of reference summary
214 | n = length of candidate summary
215 |
216 | Args:
217 | evaluated_sentences: The sentences that have been picked by the summarizer
218 |     reference_sentences: The sentences from the reference set
219 |
220 | Returns:
221 | A float: F_lcs
222 |
223 | Raises:
224 | ValueError: raises exception if a param has len <= 0
225 | """
226 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
227 | raise ValueError("Collections must contain at least 1 sentence.")
228 | reference_words = _split_into_words(reference_sentences)
229 | evaluated_words = _split_into_words(evaluated_sentences)
230 | m = len(reference_words)
231 | n = len(evaluated_words)
232 | lcs = _len_lcs(evaluated_words, reference_words)
233 | return _f_p_r_lcs(lcs, m, n)
234 |
235 |
236 | def _union_lcs(evaluated_sentences, reference_sentence):
237 | """
238 | Returns LCS_u(r_i, C) which is the LCS score of the union longest common
239 | subsequence between reference sentence ri and candidate summary C. For example
240 | if r_i= w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8 and
241 | c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1 is
242 | "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5". The
243 | union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5" and
244 | LCS_u(r_i, C) = 4/5.
245 |
246 | Args:
247 | evaluated_sentences: The sentences that have been picked by the summarizer
248 | reference_sentence: One of the sentences in the reference summaries
249 |
250 | Returns:
251 | float: LCS_u(r_i, C)
252 |
253 | ValueError:
254 | Raises exception if a param has len <= 0
255 | """
256 | if len(evaluated_sentences) <= 0:
257 | raise ValueError("Collections must contain at least 1 sentence.")
258 |
259 | lcs_union = set()
260 | reference_words = _split_into_words([reference_sentence])
261 | combined_lcs_length = 0
262 | for eval_s in evaluated_sentences:
263 | evaluated_words = _split_into_words([eval_s])
264 | lcs = set(_recon_lcs(reference_words, evaluated_words))
265 | combined_lcs_length += len(lcs)
266 | lcs_union = lcs_union.union(lcs)
267 |
268 | union_lcs_count = len(lcs_union)
269 | union_lcs_value = union_lcs_count / combined_lcs_length
270 | return union_lcs_value
271 |
272 |
273 | def rouge_l_summary_level(evaluated_sentences, reference_sentences):
274 | """
275 | Computes ROUGE-L (summary level) of two text collections of sentences.
276 | http://research.microsoft.com/en-us/um/people/cyl/download/papers/
277 | rouge-working-note-v1.3.1.pdf
278 |
279 | Calculated according to:
280 | R_lcs = SUM(1, u)[LCS(r_i,C)]/m
281 | P_lcs = SUM(1, u)[LCS(r_i,C)]/n
282 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
283 |
284 | where:
285 | SUM(i,u) = SUM from i through u
286 | u = number of sentences in reference summary
287 | C = Candidate summary made up of v sentences
288 | m = number of words in reference summary
289 | n = number of words in candidate summary
290 |
291 | Args:
292 | evaluated_sentences: The sentences that have been picked by the summarizer
293 | reference_sentence: One of the sentences in the reference summaries
294 |
295 | Returns:
296 | A float: F_lcs
297 |
298 | Raises:
299 | ValueError: raises exception if a param has len <= 0
300 | """
301 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
302 | raise ValueError("Collections must contain at least 1 sentence.")
303 |
304 | # total number of words in reference sentences
305 | m = len(_split_into_words(reference_sentences))
306 |
307 | # total number of words in evaluated sentences
308 | n = len(_split_into_words(evaluated_sentences))
309 |
310 | union_lcs_sum_across_all_references = 0
311 | for ref_s in reference_sentences:
312 | union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences,
313 | ref_s)
314 | return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n)
315 |
316 |
317 | def rouge(hypotheses, references):
318 | """Calculates average rouge scores for a list of hypotheses and
319 | references"""
320 |
321 | # Filter out hyps that are of 0 length
322 | # hyps_and_refs = zip(hypotheses, references)
323 | # hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0]
324 | # hypotheses, references = zip(*hyps_and_refs)
325 |
326 | # Calculate ROUGE-1 F1, precision, recall scores
327 | rouge_1 = [
328 | rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references)
329 | ]
330 | rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1))
331 |
332 | # Calculate ROUGE-2 F1, precision, recall scores
333 | rouge_2 = [
334 | rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references)
335 | ]
336 | rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2))
337 |
338 | # Calculate ROUGE-L F1, precision, recall scores
339 | rouge_l = [
340 | rouge_l_sentence_level([hyp], [ref])
341 | for hyp, ref in zip(hypotheses, references)
342 | ]
343 | rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l))
344 |
345 | return {
346 | "rouge_1/f_score": rouge_1_f,
347 | "rouge_1/r_score": rouge_1_r,
348 | "rouge_1/p_score": rouge_1_p,
349 | "rouge_2/f_score": rouge_2_f,
350 | "rouge_2/r_score": rouge_2_r,
351 | "rouge_2/p_score": rouge_2_p,
352 | "rouge_l/f_score": rouge_l_f,
353 | "rouge_l/r_score": rouge_l_r,
354 | "rouge_l/p_score": rouge_l_p,
355 | }
356 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import heapq
5 | import random
6 | import pickle
7 | import datetime
8 | from .rouge import rouge
9 | from .bleu import compute_bleu
10 |
11 |
12 | def rouge_score(references, generated):
13 | """both are a list of strings"""
14 | score = rouge(generated, references)
15 | rouge_s = {k: (v * 100) for (k, v) in score.items()}
16 | '''
17 | "rouge_1/f_score": rouge_1_f,
18 | "rouge_1/r_score": rouge_1_r,
19 | "rouge_1/p_score": rouge_1_p,
20 | "rouge_2/f_score": rouge_2_f,
21 | "rouge_2/r_score": rouge_2_r,
22 | "rouge_2/p_score": rouge_2_p,
23 | "rouge_l/f_score": rouge_l_f,
24 | "rouge_l/r_score": rouge_l_r,
25 | "rouge_l/p_score": rouge_l_p,
26 | '''
27 | return rouge_s
28 |
29 |
30 | def bleu_score(references, generated, n_gram=4, smooth=False):
31 | """a list of lists of tokens"""
32 | formatted_ref = [[ref] for ref in references]
33 | bleu_s, _, _, _, _, _ = compute_bleu(formatted_ref, generated, n_gram, smooth)
34 | return bleu_s * 100
35 |
36 |
37 | def two_seq_same(sa, sb):
38 | if len(sa) != len(sb):
39 | return False
40 | for (wa, wb) in zip(sa, sb):
41 | if wa != wb:
42 | return False
43 | return True
44 |
45 |
46 | def unique_sentence_percent(sequence_batch):
47 | unique_seq = []
48 | for seq in sequence_batch:
49 | count = 0
50 | for uni_seq in unique_seq:
51 | if two_seq_same(seq, uni_seq):
52 | count += 1
53 | break
54 | if count == 0:
55 | unique_seq.append(seq)
56 |
57 | return len(unique_seq) / len(sequence_batch), len(unique_seq)
58 |
59 |
60 | def feature_detect(seq_batch, feature_set):
61 | feature_batch = []
62 | for ids in seq_batch:
63 | feature_list = []
64 | for i in ids:
65 | if i in feature_set:
66 | feature_list.append(i)
67 | feature_batch.append(set(feature_list))
68 |
69 | return feature_batch
70 |
71 |
72 | def feature_matching_ratio(feature_batch, test_feature):
73 | count = 0
74 | for (fea_set, fea) in zip(feature_batch, test_feature):
75 | if fea in fea_set:
76 | count += 1
77 |
78 | return count / len(feature_batch)
79 |
80 |
81 | def feature_coverage_ratio(feature_batch, feature_set):
82 | features = set()
83 | for fb in feature_batch:
84 | features = features | fb
85 |
86 | return len(features) / len(feature_set)
87 |
88 |
89 | def feature_diversity(feature_batch):
90 | list_len = len(feature_batch)
91 |
92 | total_count = 0
93 | for i, x in enumerate(feature_batch):
94 | for j in range(i + 1, list_len):
95 | y = feature_batch[j]
96 | total_count += len(x & y)
97 |
98 | denominator = list_len * (list_len - 1) / 2
99 | return total_count / denominator
100 |
101 |
102 | def mean_absolute_error(predicted, max_r, min_r, mae=True):
103 | total = 0
104 | for (r, p) in predicted:
105 | if p > max_r:
106 | p = max_r
107 | if p < min_r:
108 | p = min_r
109 |
110 | sub = p - r
111 | if mae:
112 | total += abs(sub)
113 | else:
114 | total += sub ** 2
115 |
116 | return total / len(predicted)
117 |
118 |
119 | def root_mean_square_error(predicted, max_r, min_r):
120 | mse = mean_absolute_error(predicted, max_r, min_r, False)
121 | return math.sqrt(mse)
122 |
123 |
124 | class WordDictionary:
125 | def __init__(self):
126 |         self.idx2word = ['<bos>', '<eos>', '<pad>', '<unk>']
127 | self.__predefine_num = len(self.idx2word)
128 | self.word2idx = {w: i for i, w in enumerate(self.idx2word)}
129 | self.__word2count = {}
130 |
131 | def add_sentence(self, sentence):
132 | for w in sentence.split():
133 | self.add_word(w)
134 |
135 | def add_word(self, w):
136 | if w not in self.word2idx:
137 | self.word2idx[w] = len(self.idx2word)
138 | self.idx2word.append(w)
139 | self.__word2count[w] = 1
140 | else:
141 | self.__word2count[w] += 1
142 |
143 | def __len__(self):
144 | return len(self.idx2word)
145 |
146 | def keep_most_frequent(self, max_vocab_size=20000):
147 | if len(self.__word2count) > max_vocab_size:
148 | frequent_words = heapq.nlargest(max_vocab_size, self.__word2count, key=self.__word2count.get)
149 | self.idx2word = self.idx2word[:self.__predefine_num] + frequent_words
150 | self.word2idx = {w: i for i, w in enumerate(self.idx2word)}
151 |
152 |
153 | class EntityDictionary:
154 | def __init__(self):
155 | self.idx2entity = []
156 | self.entity2idx = {}
157 |
158 | def add_entity(self, e):
159 | if e not in self.entity2idx:
160 | self.entity2idx[e] = len(self.idx2entity)
161 | self.idx2entity.append(e)
162 |
163 | def __len__(self):
164 | return len(self.idx2entity)
165 |
166 |
167 | class DataLoader:
168 | def __init__(self, data_path, index_dir, vocab_size):
169 | self.word_dict = WordDictionary()
170 | self.user_dict = EntityDictionary()
171 | self.item_dict = EntityDictionary()
172 | self.max_rating = float('-inf')
173 | self.min_rating = float('inf')
174 | self.initialize(data_path)
175 | self.word_dict.keep_most_frequent(vocab_size)
176 |         self.__unk = self.word_dict.word2idx['<unk>']
177 | self.feature_set = set()
178 | self.train, self.valid, self.test = self.load_data(data_path, index_dir)
179 |
180 | def initialize(self, data_path):
181 | assert os.path.exists(data_path)
182 | reviews = pickle.load(open(data_path, 'rb'))
183 | for review in reviews:
184 | self.user_dict.add_entity(review['user'])
185 | self.item_dict.add_entity(review['item'])
186 | (fea, adj, tem, sco) = review['template']
187 | self.word_dict.add_sentence(tem)
188 | self.word_dict.add_word(fea)
189 | rating = review['rating']
190 | if self.max_rating < rating:
191 | self.max_rating = rating
192 | if self.min_rating > rating:
193 | self.min_rating = rating
194 |
195 | def load_data(self, data_path, index_dir):
196 | data = []
197 | reviews = pickle.load(open(data_path, 'rb'))
198 | for review in reviews:
199 | (fea, adj, tem, sco) = review['template']
200 | data.append({'user': self.user_dict.entity2idx[review['user']],
201 | 'item': self.item_dict.entity2idx[review['item']],
202 | 'rating': review['rating'],
203 | 'text': self.seq2ids(tem),
204 | 'feature': self.word_dict.word2idx.get(fea, self.__unk)})
205 | if fea in self.word_dict.word2idx:
206 | self.feature_set.add(fea)
207 | else:
208 |                 self.feature_set.add('<unk>')
209 |
210 | train_index, valid_index, test_index = self.load_index(index_dir)
211 | train, valid, test = [], [], []
212 | for idx in train_index:
213 | train.append(data[idx])
214 | for idx in valid_index:
215 | valid.append(data[idx])
216 | for idx in test_index:
217 | test.append(data[idx])
218 | return train, valid, test
219 |
220 | def seq2ids(self, seq):
221 | return [self.word_dict.word2idx.get(w, self.__unk) for w in seq.split()]
222 |
223 | def load_index(self, index_dir):
224 | assert os.path.exists(index_dir)
225 | with open(os.path.join(index_dir, 'train.index'), 'r') as f:
226 | train_index = [int(x) for x in f.readline().split(' ')]
227 | with open(os.path.join(index_dir, 'validation.index'), 'r') as f:
228 | valid_index = [int(x) for x in f.readline().split(' ')]
229 | with open(os.path.join(index_dir, 'test.index'), 'r') as f:
230 | test_index = [int(x) for x in f.readline().split(' ')]
231 | return train_index, valid_index, test_index
232 |
233 |
234 | def sentence_format(sentence, max_len, pad, bos, eos):
235 | length = len(sentence)
236 | if length >= max_len:
237 | return [bos] + sentence[:max_len] + [eos]
238 | else:
239 | return [bos] + sentence + [eos] + [pad] * (max_len - length)
240 |
241 |
242 | class Batchify:
243 | def __init__(self, data, word2idx, seq_len=15, batch_size=128, shuffle=False):
244 |         bos = word2idx['<bos>']
245 |         eos = word2idx['<eos>']
246 |         pad = word2idx['<pad>']
247 | u, i, r, t, f = [], [], [], [], []
248 | for x in data:
249 | u.append(x['user'])
250 | i.append(x['item'])
251 | r.append(x['rating'])
252 | t.append(sentence_format(x['text'], seq_len, pad, bos, eos))
253 | f.append([x['feature']])
254 |
255 | self.user = torch.tensor(u, dtype=torch.int64).contiguous()
256 | self.item = torch.tensor(i, dtype=torch.int64).contiguous()
257 | self.rating = torch.tensor(r, dtype=torch.float).contiguous()
258 | self.seq = torch.tensor(t, dtype=torch.int64).contiguous()
259 | self.feature = torch.tensor(f, dtype=torch.int64).contiguous()
260 | self.shuffle = shuffle
261 | self.batch_size = batch_size
262 | self.sample_num = len(data)
263 | self.index_list = list(range(self.sample_num))
264 | self.total_step = int(math.ceil(self.sample_num / self.batch_size))
265 | self.step = 0
266 |
267 | def next_batch(self):
268 | if self.step == self.total_step:
269 | self.step = 0
270 | if self.shuffle:
271 | random.shuffle(self.index_list)
272 |
273 | start = self.step * self.batch_size
274 | offset = min(start + self.batch_size, self.sample_num)
275 | self.step += 1
276 | index = self.index_list[start:offset]
277 | user = self.user[index] # (batch_size,)
278 | item = self.item[index]
279 | rating = self.rating[index]
280 | seq = self.seq[index] # (batch_size, seq_len)
281 | feature = self.feature[index] # (batch_size, 1)
282 | return user, item, rating, seq, feature
283 |
284 |
285 | def now_time():
286 | return '[' + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + ']: '
287 |
288 |
289 | def ids2tokens(ids, word2idx, idx2word):
290 |     eos = word2idx['<eos>']
291 | tokens = []
292 | for i in ids:
293 | if i == eos:
294 | break
295 | tokens.append(idx2word[i])
296 | return tokens
297 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate_chatglm_result.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # --------------------------------------------
4 | # @FileName: evaluate_chatglm_result.py
5 | # @Author: ljl
6 | # @Time: 2023/5/10
7 | # @Description:
8 | # --------------------------------------------
9 |
10 | import os
11 | import re
12 | from transformers import AutoTokenizer, AutoModel
13 | import argparse
14 | import pandas as pd
15 | from tqdm import tqdm
16 | import jieba
17 | template_multi = "假设你是一位医疗行业专家,请回答下列问题。注意,该问题是多选题\n" \
18 | "{}:\n{}\n" \
19 | "注意,请给出两行,第一行只需要返回答案的英文选项,第二行进行简要的解释。输出格式限制为“答案:”,“解释:”"
20 |
21 | template_single = "返回限制:只返回两行。" \
22 | "假设你是一位医疗行业专家,请回答下列问题,注意是单选题,只需要返回一个最合适的选项。\n" \
23 | "{}:\n{}\n" \
24 | "注意,结果只有两行,第一行只需要返回答案的英文选项(注意只需要返回一个最合适的答案),第二行进行简要的解释。输出格式限制为:“答案:”,“解释:”。\n" \
25 | "注意,题目是单选题,若有多个合适的答案,只返回最准确的即可。"
26 |
27 | def prediction(args):
28 | # load model
29 |     tokenizer = AutoTokenizer.from_pretrained(args.tokenizerpath, trust_remote_code=True)
30 |     model = AutoModel.from_pretrained(args.modelpath, trust_remote_code=True).half().cuda()
31 | model = model.eval()
32 |
33 | def predict(data):
34 | results = []
35 | for content in tqdm(data):
36 | try:
37 | response, history = model.chat(tokenizer, content, history=[])
38 | except Exception as e:
39 | response = ""
40 | results.append(response)
41 | return results
42 |
43 | # load csv
44 | csv = pd.read_csv(args.filepath)
45 | questions = csv['Question'].values.tolist()
46 | options = csv['Options'].values.tolist()
47 | gt_answer = csv['Answer'].values.tolist()
48 |
49 | data = []
50 | raw_results = []
51 | for i in range(len(questions)):
52 | if len(gt_answer[i]) == 1:
53 | data.append(template_single.format(questions[i], options[i]))
54 | else:
55 | data.append(template_multi.format(questions[i], options[i]))
56 |
57 | raw_results.extend(predict(data))
58 | predicted_answer = []
59 | predicted_explanation = []
60 | for single in raw_results:
61 | try:
62 | answer = re.findall(r"答案:(.*),", single)[0]
63 | exp = re.findall(r"解释:(.*)", single)[0]
64 | predicted_answer.append(answer)
65 | predicted_explanation.append(exp)
66 | except Exception as e:
67 | print(single, flush=True)
68 | predicted_answer.append("")
69 | predicted_explanation.append("")
70 |
71 | csv['raw_prediction'] = raw_results
72 | csv['predicted_answer'] = predicted_answer
73 | csv['predicted_explanation'] = predicted_explanation
74 |
75 |     if not os.path.exists(os.path.dirname(args.savepath)):
76 |         os.makedirs(os.path.dirname(args.savepath))
77 |     csv.to_csv(args.savepath, index=False)
78 |
79 |
80 | def evaluation(args):
81 | csv = pd.read_csv(args.savepath)
82 |
83 | gt_exp = csv['Explanation'].values.tolist()
84 | predict_exp = csv['predicted_explanation'].values.tolist()
85 | # process pd.na
86 | gt_exp = [item if not pd.isna(item) else "" for item in gt_exp]
87 | predict_exp = [item if not pd.isna(item) else "" for item in predict_exp]
88 |
89 | gt_answer = csv['Answer'].values.tolist()
90 | predict_answer = csv['predicted_answer'].values.tolist()
91 | gt_answer_with_value = []
92 | predict_answer_with_value = []
93 |
94 | total = 0.0
95 | correct = 0.0
96 | for i in range(len(gt_answer)):
97 | if not pd.isna(predict_answer[i]):
98 | total += 1
99 | gt_answer_with_value.append(gt_answer[i])
100 | predict_answer_with_value.append(predict_answer[i])
101 | if gt_answer[i] == predict_answer[i]:
102 | correct += 1
103 |
104 | gt_answer = gt_answer_with_value
105 | predict_answer = predict_answer_with_value
106 |
107 | print(total)
108 | print(correct / total)
109 |
110 | from sklearn.metrics import precision_recall_fscore_support
111 | precison, recall, fscore, _ = precision_recall_fscore_support(gt_answer, predict_answer, average='weighted')
112 | print('Precision: ', precison)
113 | print('Recall: ', recall)
114 | print('Fscore: ', fscore)
115 |
116 | from evaluate.utils import rouge_score, bleu_score, unique_sentence_percent, root_mean_square_error, \
117 | mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity
118 |
119 | tokens_of_processed_predict_exps = [list(jieba.cut(item, cut_all=False)) for item in predict_exp]
120 | tokens_of_processed_gt_exps = [list(jieba.cut(item, cut_all=False)) for item in gt_exp]
121 | # tokens_of_processed_predict_exps = [list(item) for item in predict_exp]
122 | # tokens_of_processed_gt_exps = [list(item) for item in gt_exp]
123 |
124 |     processed_gt_exps = [' '.join(list(item)) for item in gt_exp]
125 |     processed_predict_exps = [' '.join(list(item)) for item in predict_exp]
126 |
127 | BLEU1 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=1, smooth=False)
128 | BLEU2 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=2, smooth=False)
129 | BLEU4 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=4, smooth=False)
130 | ROUGE = rouge_score(processed_gt_exps, processed_predict_exps)
131 |
132 | print('BLEU-1 {:7.4f}'.format(BLEU1))
133 | print('BLEU-2 {:7.4f}'.format(BLEU2))
134 | print('BLEU-4 {:7.4f}'.format(BLEU4))
135 | for (k, v) in ROUGE.items():
136 | print('{} {:7.4f}'.format(k, v))
137 |
138 |
139 | if __name__ == '__main__':
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument("--filepath", type=str, default="../../data/test_with_annotations.csv")
142 | parser.add_argument("--savepath", type=str, default="../exp/test_with_chatglm.csv")
143 | parser.add_argument("--modelpath", type=str, default="THUDM/chatglm-6b")
144 | parser.add_argument("--tokenizerpath", type=str, default="THUDM/chatglm-6b")
145 | args = parser.parse_args()
146 | prediction(args)
147 | evaluation(args)
--------------------------------------------------------------------------------
/src/evaluation/evaluate_ft_result.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # --------------------------------------------
4 | # @FileName: calc_metrics.py
5 | # @Author: ljl
6 | # @Time: 2023/5/10
7 | # @Description:
8 | # --------------------------------------------
9 |
10 | import pandas as pd
11 | import jieba
12 |
13 | filepath = 'test_predicted.csv'
14 |
15 | csv = pd.read_csv(filepath)
16 |
17 | gt_exp = csv['Explanation'].values.tolist()
18 | predict_exp = csv['explanation'].values.tolist()
19 | # process pd.na
20 | gt_exp = [item if not pd.isna(item) else "" for item in gt_exp]
21 | predict_exp = [item if not pd.isna(item) else "" for item in predict_exp]
22 |
23 | gt_answer = csv['Answer'].values.tolist()
24 | predict_answer = csv['answer_prediction'].values.tolist()
25 | gt_answer_with_value = []
26 | predict_answer_with_value = []
27 | 
28 | total = 0.0
29 | correct = 0.0
30 | for i in range(len(gt_answer)):
31 |     if not pd.isna(predict_answer[i]):
32 |         total += 1
33 |         gt_answer_with_value.append(gt_answer[i])
34 |         predict_answer_with_value.append(predict_answer[i])
35 |         if gt_answer[i] == predict_answer[i]:
36 |             correct += 1
37 | 
38 | 
39 | gt_answer = gt_answer_with_value
40 | predict_answer = predict_answer_with_value
41 | 
42 | print(total)
43 | print(correct / total)
44 |
45 | from sklearn.metrics import precision_recall_fscore_support
46 | precison, recall, fscore, _ = precision_recall_fscore_support(gt_answer, predict_answer, average='weighted')
47 | print('Precision: ', precison)
48 | print('Recall: ', recall)
49 | print('Fscore: ', fscore)
50 |
51 | from src.evaluation.evaluate.utils import rouge_score, bleu_score, unique_sentence_percent, root_mean_square_error, mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity
52 |
53 | tokens_of_processed_predict_exps = [list(jieba.cut(item,cut_all=False)) for item in predict_exp]
54 | tokens_of_processed_gt_exps = [list(jieba.cut(item,cut_all=False)) for item in gt_exp]
55 |
56 | # tokens_of_processed_predict_exps = [list(item) for item in predict_exp]
57 | # tokens_of_processed_gt_exps = [list(item) for item in gt_exp]
58 | processed_gt_exps = [' '.join(list(item)) for item in gt_exp]
59 | processed_predict_exps = [' '.join(list(item)) for item in predict_exp]
60 |
61 | BLEU1 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=1, smooth=False)
62 | BLEU2 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=2, smooth=False)
63 | BLEU4 = bleu_score(tokens_of_processed_gt_exps, tokens_of_processed_predict_exps, n_gram=4, smooth=False)
64 | ROUGE = rouge_score(processed_gt_exps, processed_predict_exps)
65 |
66 | print('BLEU-1 {:7.4f}'.format(BLEU1))
67 | print('BLEU-2 {:7.4f}'.format(BLEU2))
68 | print('BLEU-4 {:7.4f}'.format(BLEU4))
69 | for (k, v) in ROUGE.items():
70 | print('{} {:7.4f}'.format(k, v))
71 |
--------------------------------------------------------------------------------
/src/evaluation/evaluate_gpt_result.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # --------------------------------------------
4 | # @FileName: translate.py
5 | # @Author: ljl
6 | # @Time: 2023/5/4
7 | # @Description:
8 | # --------------------------------------------
9 |
10 | import openai
11 | import argparse
12 | import os
13 | import time
14 | import jieba
15 | from multiprocessing import Pool
16 | import pandas as pd
17 | from tqdm import tqdm
18 | os.environ["HTTP_PROXY"] = "socks5h://127.0.0.1:13659"
19 | os.environ["HTTPS_PROXY"] = "socks5h://127.0.0.1:13659"
20 | # os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
21 | # os.environ["HTTPS_PROXY"] = "https://127.0.0.1:7890"
22 |
23 | def call_api(data,question_nums, model):
24 | results = []
25 | try:
26 | for i, content in tqdm(enumerate(data)):
27 | result = ""
28 | try:
29 | completion = openai.ChatCompletion.create(
30 | model=model,
31 | # model="gpt-4",
32 | # model="gpt-4-0314",
33 | messages=[{"role": "user", "content": content}]
34 | )
35 | result = completion.choices[0].message.content
36 | except Exception as e:
37 | print(str(e), flush=True)
38 | results.append(result)
39 | except Exception as e:
40 | print(str(e), flush=True)
41 | results.extend(["[]" for _ in range(len(data)-len(results))])
42 | return results,question_nums
43 |
44 | def prediction(args):
45 | openai.api_key = args.api_key
46 |
47 | csv = pd.read_csv(args.filepath)
48 | questions = csv['Question'].values.tolist()
49 | options = csv['Options'].values.tolist()
50 |
51 | template = "返回格式为一个python列表,包含每道题的答案英文选项和解释 \n" \
52 | "假设你是一位医疗行业专家,请回答下列几个问题。\n" \
53 | "题目信息为:{} \n" \
54 | "注意,每个题目的回答以一个字符串保存,返回答案的英文选项,并进行简要的解释。字符串输出格式限制为“答案:**,解释:**”"
55 | data = []
56 | question_nums = []
57 | step = 5
58 |
59 | for i in range(0,len(questions),step):
60 | question_group = ""
61 | question_num = min(step, len(questions)-i)
62 | for j in range(question_num):
63 | question_group+="{}.题目信息为 {}:{}\n".format(str(j+1),questions[i+j], options[i+j].replace('\n',','))
64 |
65 | data.append(template.format(question_group))
66 | question_nums.append(question_num)
67 |
68 | # data = data[:2]
69 | # question_nums = question_nums[:2]
70 |
71 | # multiprocessing
72 | num_of_processes = 1
73 | pool = Pool(processes=num_of_processes)
74 | pool_results = []
75 | each_size = len(data) // num_of_processes
76 | for i in range(num_of_processes):
--------------------------------------------------------------------------------
/src/evaluation/evaluate_lora_result.py:
--------------------------------------------------------------------------------
33 |     if len(predict_ops) > 0:
34 | return "".join(predict_ops)
35 | else:
36 | return "无答案"
37 | def parse_explanations(row):
38 |     # Extract the explanation part from 'model_result'
39 | if not isinstance(row['model_result'],str):
40 | return "无答案"
41 | if '解释:' not in row['model_result']:
42 | original_result = row['model_result']
43 | else:
44 | original_result = row['model_result'].split('解释:')[1].strip()
45 | return original_result
46 | def evaluate_reasoning(df):
47 | def add_spaces(l):
48 | return [' '.join(list(_)) for _ in l]
49 | source = '答案解析'
50 | target = 'parsed_explanation'
51 | df.dropna(subset=[source, target], inplace=True)
52 | tokens_predict = df[target].to_list()
53 | tokens_test = df[source].to_list()
54 |
55 | tokens_predict = add_spaces(tokens_predict)
56 | tokens_test = add_spaces(tokens_test)
57 |
58 | new_tokens_predict = [l.split() for l in tokens_predict]
59 | new_tokens_test = [ll.split() for ll in tokens_test]
60 | BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False)
61 | BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False)
62 | ROUGE = rouge_score(tokens_test, tokens_predict)
63 |
64 | print('BLEU-1 {:7.4f}'.format(BLEU1))
65 | print('BLEU-4 {:7.4f}'.format(BLEU4))
66 | for (k, v) in ROUGE.items():
67 | if 'f_score' in k:
68 | print('{} {:7.4f}'.format(k, v))
69 |
70 | def evaluate_prediction(df):
71 | correct = df[df['parsed_option']==df['答案']].shape[0]
72 | total = df.shape[0]
73 | num_no_answer = df[df['parsed_option']=='无答案'].shape[0]
74 |
75 | processed_gts = df['答案'].to_list()
76 | processed_results = df['parsed_option'].to_list()
77 | precison, recall, fscore, _ = precision_recall_fscore_support(processed_gts, processed_results, average='weighted')
78 | print('Precision: ', precison)
79 | print('Recall: ', recall)
80 | print('Fscore: ', fscore)
81 | print('Acc:{}'.format(correct/total*100))
82 |     print('Number of "无答案" (no answer) predictions:', num_no_answer)
83 |
84 | def main(
85 | csv_file_path: str = "../LoRA/output/medalpaca_4.csv",
86 | ):
87 | df = pd.read_csv(csv_file_path)
88 |
89 | df['parsed_option'] = df.apply(parse_options,axis=1)
90 | df['parsed_explanation'] = df.apply(parse_explanations,axis=1)
91 |
92 | print('Evaluation of prediction:')
93 | evaluate_prediction(df)
94 | print('*'*20)
95 | print('Evaluation of reasoning:')
96 | evaluate_reasoning(df)
97 |
98 | if __name__ == "__main__":
99 | fire.Fire(main)
--------------------------------------------------------------------------------
/src/preprocess/dataset_dist.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/williamliujl/CMExam/fadb22c89beb1b7115dc36460ba792eb96b7b972/src/preprocess/dataset_dist.pdf
--------------------------------------------------------------------------------
/src/preprocess/generate_prompt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # --------------------------------------------
4 | # @FileName: generate_prompt.py
5 | # @Author: ljl
6 | # @Time: 2023/5/15
7 | # @Description:
8 | # --------------------------------------------
9 |
10 | import pandas as pd
11 | import json
12 | import copy
13 | import argparse
14 | from prompt_templates import all_task_templates
15 |
16 |
17 | def main(args):
18 |
19 | filepath = args.filepath
20 |
21 | csv = pd.read_csv(filepath)
22 |
23 | # prompt_templates = ["1","2","3","4","5","6"]
24 | prompt_templates = args.templates.split(",")
25 |
26 | prompts = []
27 |
28 | for i,data in enumerate(csv.values):
29 |
30 | question = data[csv.columns.values.tolist().index("Question")]
31 | options = data[csv.columns.values.tolist().index("Options")]
32 | explanation = data[csv.columns.values.tolist().index("Explanation")]
33 | option_lists = options.split("\n")
34 | answer = data[csv.columns.values.tolist().index("Answer")]
35 | if pd.isna(answer):
36 | continue
37 | answer_content = ""
38 | for option in option_lists:
39 | if option.split(" ")[0] == answer:
40 | answer_content = option.split(" ")[-1]
41 |
42 | for prompt_idx in prompt_templates:
43 | prompt_template = copy.deepcopy(all_task_templates[prompt_idx])
44 | try:
45 | if prompt_idx == "1":
46 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options)
47 | prompt_template["completion"] = prompt_template["completion"].format(answer)
48 | prompts.append(prompt_template)
49 | elif prompt_idx == "2":
50 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options)
51 | prompt_template["completion"] = prompt_template["completion"].format(answer+" "+ answer_content)
52 | prompts.append(prompt_template)
53 | elif prompt_idx == "3":
54 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options)
55 | prompt_template["completion"] = prompt_template["completion"].format(explanation)
56 | prompts.append(prompt_template)
57 | elif prompt_idx == "4":
58 | prompt_template["prompt"] = prompt_template["prompt"].format(question, options)
59 | prompt_template["completion"] = prompt_template["completion"].format(answer+" "+ answer_content, explanation)
60 | prompts.append(prompt_template)
61 | elif prompt_idx == "5":
62 | prompt_template["prompt"] = prompt_template["prompt"].format(question)
63 | prompt_template["completion"] = prompt_template["completion"].format(answer_content)
64 | prompts.append(prompt_template)
65 | elif prompt_idx == "6":
66 | prompt_template["prompt"] = prompt_template["prompt"].format(question)
67 | prompt_template["completion"] = prompt_template["completion"].format(answer_content, explanation)
68 | prompts.append(prompt_template)
69 | except Exception as e:
70 | print(data)
71 |
72 | # save json
73 | savepath = filepath.replace(".csv", ".json")
74 | with open(savepath, 'w') as f:
75 | for prompt in prompts:
76 | json_file = {
77 | "prompt":prompt["prompt"],
78 | "completion":prompt["completion"],
79 | "id":prompt["id"]
80 | }
81 | json_str = json.dumps(json_file,ensure_ascii=False)
82 | f.write(json_str + '\n')
83 | f.close()
84 |
85 | # save csv
86 |     savepath = filepath.replace(".csv", "_prompt.csv")
87 |     csv["prompt"] = [prompt["prompt"] for prompt in prompts]  # assumes one template per row and no skipped rows
88 |     csv["completion"] = [prompt["completion"] for prompt in prompts]
89 |     csv["id"] = [prompt["id"] for prompt in prompts]
90 | csv.to_csv(savepath)
91 |
92 | if __name__ == "__main__":
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--filepath", type=str, required=True)
95 | parser.add_argument("--templates", type=str, default="1,2", help="To generate prompts using different templates")
96 | args = parser.parse_args()
97 | main(args)
98 |
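generate_prompt.py is run from the command line, e.g. python generate_prompt.py --filepath ../../data/train.csv --templates 4 (the path is illustrative). It writes one JSON object per line with the prompt, completion and id keys, which appear to be what the ptuning scripts further down consume via --prompt_column prompt --response_column completion. A sketch of a single record for template "4", with placeholder text instead of real exam content:

import json

# Illustrative record only; real questions, options and explanations come from the CMExam CSV.
record = {
    "prompt": "问题: <question text>, \n 选项: <options, one per line>",
    "completion": "答案: <answer letter> <answer text>. \n 解释:<explanation text>",
    "id": "4",
}
print(json.dumps(record, ensure_ascii=False))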
--------------------------------------------------------------------------------
/src/preprocess/prompt_templates.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # --------------------------------------------
4 | # @FileName: prompt_templates.py
5 | # @Author: ljl
6 | # @Time: 2023/5/15
7 | # @Description:
8 | # --------------------------------------------
9 |
10 | all_task_templates = {}
11 |
12 | template = {}
13 | template['prompt'] = "问题: {}, \n 选项: {}"
14 | template['completion'] = "答案: {}"
15 | template['id'] = "1"
16 | all_task_templates["1"] = template
17 |
18 | template = {}
19 | template['prompt'] = "问题: {}, \n 选项: {}"
20 | template['completion'] = "答案: {}"
21 | template['id'] = "2"
22 | all_task_templates["2"] = template
23 |
24 | template = {}
25 | template['prompt'] = "问题: {}, \n 选项: {}"
26 | template['completion'] = "解释: {}"
27 | template['id'] = "3"
28 | all_task_templates["3"] = template
29 |
30 | template = {}
31 | template['prompt'] = "问题: {}, \n 选项: {}"
32 | template['completion'] = "答案: {}. \n 解释:{}"
33 | template['id'] = "4"
34 | all_task_templates["4"] = template
35 |
36 | template = {}
37 | template['prompt'] = "问题: {}"
38 | template['completion'] = "答案: {}"
39 | template['id'] = "5"
40 | all_task_templates["5"] = template
41 |
42 | template = {}
43 | template['prompt'] = "问题: {}"
44 | template['completion'] = "答案: {}. \n 解释: {}"
45 | template['id'] = "6"
46 | all_task_templates["6"] = template
--------------------------------------------------------------------------------
/src/ptuning/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 |
11 | model_name_or_path: str = field(
12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 | )
14 | ptuning_checkpoint: str = field(
15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16 | )
17 | config_name: Optional[str] = field(
18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19 | )
20 | tokenizer_name: Optional[str] = field(
21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22 | )
23 | cache_dir: Optional[str] = field(
24 | default=None,
25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26 | )
27 | use_fast_tokenizer: bool = field(
28 | default=True,
29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30 | )
31 | model_revision: str = field(
32 | default="main",
33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34 | )
35 | use_auth_token: bool = field(
36 | default=False,
37 | metadata={
38 | "help": (
39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40 | "with private models)."
41 | )
42 | },
43 | )
44 | resize_position_embeddings: Optional[bool] = field(
45 | default=None,
46 | metadata={
47 | "help": (
48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49 | "the model's position embeddings."
50 | )
51 | },
52 | )
53 |     quantization_bit: Optional[int] = field(
54 |         default=None, metadata={"help": "Quantization bit width for the model weights (e.g. 4 or 8); None disables quantization."}
55 |     )
56 |     pre_seq_len: Optional[int] = field(
57 |         default=None, metadata={"help": "Prefix (soft prompt) length for P-tuning v2; None means full-parameter finetuning."}
58 |     )
59 |     prefix_projection: bool = field(
60 |         default=False, metadata={"help": "Whether to add an MLP projection over the prefix embeddings (P-tuning v2)."}
61 |     )
62 |
63 |
64 | @dataclass
65 | class DataTrainingArguments:
66 | """
67 | Arguments pertaining to what data we are going to input our model for training and eval.
68 | """
69 |
70 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
71 |
72 | dataset_name: Optional[str] = field(
73 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
74 | )
75 | dataset_config_name: Optional[str] = field(
76 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
77 | )
78 |     prompt_column: Optional[str] = field(
79 |         default=None,
80 |         metadata={"help": "The name of the column in the datasets containing the prompts (model inputs)."},
81 |     )
82 |     response_column: Optional[str] = field(
83 |         default=None,
84 |         metadata={"help": "The name of the column in the datasets containing the target responses (completions)."},
85 |     )
86 | history_column: Optional[str] = field(
87 | default=None,
88 | metadata={"help": "The name of the column in the datasets containing the history of chat."},
89 | )
90 | train_file: Optional[str] = field(
91 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
92 | )
93 | validation_file: Optional[str] = field(
94 | default=None,
95 | metadata={
96 | "help": (
97 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
98 | )
99 | },
100 | )
101 | test_file: Optional[str] = field(
102 | default=None,
103 | metadata={
104 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
105 | },
106 | )
107 | overwrite_cache: bool = field(
108 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
109 | )
110 | preprocessing_num_workers: Optional[int] = field(
111 | default=None,
112 | metadata={"help": "The number of processes to use for the preprocessing."},
113 | )
114 | max_source_length: Optional[int] = field(
115 | default=1024,
116 | metadata={
117 | "help": (
118 | "The maximum total input sequence length after tokenization. Sequences longer "
119 | "than this will be truncated, sequences shorter will be padded."
120 | )
121 | },
122 | )
123 | max_target_length: Optional[int] = field(
124 | default=128,
125 | metadata={
126 | "help": (
127 | "The maximum total sequence length for target text after tokenization. Sequences longer "
128 | "than this will be truncated, sequences shorter will be padded."
129 | )
130 | },
131 | )
132 | val_max_target_length: Optional[int] = field(
133 | default=None,
134 | metadata={
135 | "help": (
136 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
137 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
138 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
139 | "during ``evaluate`` and ``predict``."
140 | )
141 | },
142 | )
143 | pad_to_max_length: bool = field(
144 | default=False,
145 | metadata={
146 | "help": (
147 | "Whether to pad all samples to model maximum sentence length. "
148 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
149 | "efficient on GPU but very bad for TPU."
150 | )
151 | },
152 | )
153 | max_train_samples: Optional[int] = field(
154 | default=None,
155 | metadata={
156 | "help": (
157 | "For debugging purposes or quicker training, truncate the number of training examples to this "
158 | "value if set."
159 | )
160 | },
161 | )
162 | max_eval_samples: Optional[int] = field(
163 | default=None,
164 | metadata={
165 | "help": (
166 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
167 | "value if set."
168 | )
169 | },
170 | )
171 | max_predict_samples: Optional[int] = field(
172 | default=None,
173 | metadata={
174 | "help": (
175 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
176 | "value if set."
177 | )
178 | },
179 | )
180 | num_beams: Optional[int] = field(
181 | default=None,
182 | metadata={
183 | "help": (
184 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
185 | "which is used during ``evaluate`` and ``predict``."
186 | )
187 | },
188 | )
189 | ignore_pad_token_for_loss: bool = field(
190 | default=True,
191 | metadata={
192 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
193 | },
194 | )
195 | source_prefix: Optional[str] = field(
196 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
197 | )
198 |
199 | forced_bos_token: Optional[str] = field(
200 | default=None,
201 | metadata={
202 | "help": (
203 | "The token to force as the first generated token after the decoder_start_token_id."
204 | "Useful for multilingual models like mBART where the first generated token"
205 | "needs to be the target language token (Usually it is the target language token)"
206 | )
207 | },
208 | )
209 |
210 |
211 |
212 | def __post_init__(self):
213 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
214 | raise ValueError("Need either a dataset name or a training/validation/test file.")
215 | else:
216 | if self.train_file is not None:
217 | extension = self.train_file.split(".")[-1]
218 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
219 | if self.validation_file is not None:
220 | extension = self.validation_file.split(".")[-1]
221 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
222 | if self.val_max_target_length is None:
223 | self.val_max_target_length = self.max_target_length
224 |
225 |
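These dataclasses are consumed by HfArgumentParser in main.py, so every field can be supplied either as a command-line flag (as in train.sh / prediction.sh) or programmatically. A minimal sketch of the programmatic route, with illustrative model and data paths (parse_dict is a standard HfArgumentParser method):

from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from arguments import ModelArguments, DataTrainingArguments

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
model_args, data_args, training_args = parser.parse_dict({
    "model_name_or_path": "THUDM/chatglm-6b",       # illustrative model path
    "train_file": "../../data/train_prompt.json",   # illustrative data path
    "prompt_column": "prompt",
    "response_column": "completion",
    "pre_seq_len": 128,
    "output_dir": "./output/debug",
    "do_train": True,
})
print(model_args.pre_seq_len, data_args.train_file, training_args.output_dir)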
--------------------------------------------------------------------------------
/src/ptuning/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "zero_allow_untested_optimizer": true,
4 | "fp16": {
5 | "enabled": "auto",
6 | "loss_scale": 0,
7 | "initial_scale_power": 16,
8 | "loss_scale_window": 1000,
9 | "hysteresis": 2,
10 | "min_loss_scale": 1
11 | },
12 | "zero_optimization": {
13 | "stage": 2,
14 | "allgather_partitions": true,
15 | "allgather_bucket_size": 5e8,
16 | "overlap_comm": false,
17 | "reduce_scatter": true,
18 | "reduce_bucket_size": 5e8,
19 | "contiguous_gradients" : true
20 | }
21 | }
--------------------------------------------------------------------------------
/src/ptuning/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | import json
25 |
26 | import numpy as np
27 | from datasets import load_dataset
28 | import jieba
29 | from rouge_chinese import Rouge
30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31 | import torch
32 |
33 | import transformers
34 | from transformers import (
35 | AutoConfig,
36 | AutoModel,
37 | AutoTokenizer,
38 | AutoTokenizer,
39 | DataCollatorForSeq2Seq,
40 | HfArgumentParser,
41 | Seq2SeqTrainingArguments,
42 | set_seed,
43 | )
44 | from trainer_seq2seq import Seq2SeqTrainer
45 |
46 | from arguments import ModelArguments, DataTrainingArguments
47 |
48 | logger = logging.getLogger(__name__)
49 |
50 | def main():
51 |
52 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
53 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
54 | # If we pass only one argument to the script and it's the path to a json file,
55 | # let's parse it to get our arguments.
56 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
57 | else:
58 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
59 |
60 | # Setup logging
61 | logging.basicConfig(
62 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
63 | datefmt="%m/%d/%Y %H:%M:%S",
64 | handlers=[logging.StreamHandler(sys.stdout)],
65 | )
66 |
67 | if training_args.should_log:
68 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
69 | transformers.utils.logging.set_verbosity_info()
70 |
71 | log_level = training_args.get_process_log_level()
72 | logger.setLevel(log_level)
73 | # datasets.utils.logging.set_verbosity(log_level)
74 | transformers.utils.logging.set_verbosity(log_level)
75 | transformers.utils.logging.enable_default_handler()
76 | transformers.utils.logging.enable_explicit_format()
77 |
78 | # Log on each process the small summary:
79 | logger.warning(
80 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
81 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
82 | )
83 | logger.info(f"Training/evaluation parameters {training_args}")
84 |
85 | # Set seed before initializing model.
86 | set_seed(training_args.seed)
87 |
88 | # Load dataset
89 | data_files = {}
90 | if data_args.train_file is not None:
91 | data_files["train"] = data_args.train_file
92 | extension = data_args.train_file.split(".")[-1]
93 | if data_args.validation_file is not None:
94 | data_files["validation"] = data_args.validation_file
95 | extension = data_args.validation_file.split(".")[-1]
96 | if data_args.test_file is not None:
97 | data_files["test"] = data_args.test_file
98 | extension = data_args.test_file.split(".")[-1]
99 |
100 | raw_datasets = load_dataset(
101 | extension,
102 | data_files=data_files,
103 | cache_dir=model_args.cache_dir,
104 | use_auth_token=True if model_args.use_auth_token else None,
105 | )
106 |
107 | # Load pretrained model and tokenizer
108 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
109 | config.pre_seq_len = model_args.pre_seq_len
110 | config.prefix_projection = model_args.prefix_projection
111 |
112 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
113 |
114 | if model_args.ptuning_checkpoint is not None:
115 | # Evaluation
116 | # Loading extra state dict of prefix encoder
117 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
118 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
119 | new_prefix_state_dict = {}
120 | for k, v in prefix_state_dict.items():
121 | if k.startswith("transformer.prefix_encoder."):
122 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
123 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
124 | else:
125 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
126 |
127 | if model_args.quantization_bit is not None:
128 | print(f"Quantized to {model_args.quantization_bit} bit")
129 | model = model.quantize(model_args.quantization_bit)
130 | if model_args.pre_seq_len is not None:
131 | # P-tuning v2
132 | model = model.half()
133 | model.transformer.prefix_encoder.float()
134 | else:
135 | # Finetune
136 | model = model.float()
137 |
138 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
139 |
140 | # Preprocessing the datasets.
141 | # We need to tokenize inputs and targets.
142 | if training_args.do_train:
143 | column_names = raw_datasets["train"].column_names
144 | elif training_args.do_eval:
145 | column_names = raw_datasets["validation"].column_names
146 | elif training_args.do_predict:
147 | column_names = raw_datasets["test"].column_names
148 | else:
149 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
150 | return
151 |
152 | # Get the column names for input/target.
153 | prompt_column = data_args.prompt_column
154 | response_column = data_args.response_column
155 | history_column = data_args.history_column
156 |
157 | # Temporarily set max_target_length for training.
158 | max_target_length = data_args.max_target_length
159 |
160 | def preprocess_function_eval(examples):
161 | inputs, targets = [], []
162 | for i in range(len(examples[prompt_column])):
163 | if examples[prompt_column][i] and examples[response_column][i]:
164 | query = examples[prompt_column][i]
165 | if history_column is None or len(examples[history_column][i]) == 0:
166 | prompt = query
167 | else:
168 | prompt = ""
169 | history = examples[history_column][i]
170 | for turn_idx, (old_query, response) in enumerate(history):
171 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
172 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
173 | inputs.append(prompt)
174 | targets.append(examples[response_column][i])
175 |
176 | inputs = [prefix + inp for inp in inputs]
177 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
178 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
179 |
180 | if data_args.ignore_pad_token_for_loss:
181 | labels["input_ids"] = [
182 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
183 | ]
184 | model_inputs["labels"] = labels["input_ids"]
185 |
186 | return model_inputs
187 |
188 | def preprocess_function_train(examples):
189 | max_seq_length = data_args.max_source_length + data_args.max_target_length
190 |
191 | model_inputs = {
192 | "input_ids": [],
193 | "labels": [],
194 | }
195 | for i in range(len(examples[prompt_column])):
196 | if examples[prompt_column][i] and examples[response_column][i]:
197 | query, answer = examples[prompt_column][i], examples[response_column][i]
198 |
199 | if history_column is None:
200 | prompt = query
201 | else:
202 | prompt = ""
203 | history = examples[history_column][i]
204 | for turn_idx, (old_query, response) in enumerate(history):
205 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
206 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
207 |
208 | prompt = prefix + prompt
209 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
210 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
211 |
212 | if len(a_ids) > data_args.max_source_length - 1:
213 | a_ids = a_ids[: data_args.max_source_length - 1]
214 |
215 | if len(b_ids) > data_args.max_target_length - 2:
216 | b_ids = b_ids[: data_args.max_target_length - 2]
217 |
218 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
219 |
220 |                 # mask the prompt/context tokens with -100 so that only the response contributes to the loss
221 | context_length = input_ids.index(tokenizer.bos_token_id)
222 | mask_position = context_length - 1
223 | labels = [-100] * context_length + input_ids[mask_position+1:]
224 |
225 | pad_len = max_seq_length - len(input_ids)
226 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
227 | labels = labels + [tokenizer.pad_token_id] * pad_len
228 | if data_args.ignore_pad_token_for_loss:
229 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
230 |
231 | model_inputs["input_ids"].append(input_ids)
232 | model_inputs["labels"].append(labels)
233 |
234 | return model_inputs
235 |
236 | def print_dataset_example(example):
237 | print("input_ids",example["input_ids"])
238 | print("inputs", tokenizer.decode(example["input_ids"]))
239 | print("label_ids", example["labels"])
240 | print("labels", tokenizer.decode(example["labels"]))
241 |
242 | if training_args.do_train:
243 | if "train" not in raw_datasets:
244 | raise ValueError("--do_train requires a train dataset")
245 | train_dataset = raw_datasets["train"]
246 | if data_args.max_train_samples is not None:
247 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
248 | train_dataset = train_dataset.select(range(max_train_samples))
249 | with training_args.main_process_first(desc="train dataset map pre-processing"):
250 | train_dataset = train_dataset.map(
251 | preprocess_function_train,
252 | batched=True,
253 | num_proc=data_args.preprocessing_num_workers,
254 | remove_columns=column_names,
255 | load_from_cache_file=not data_args.overwrite_cache,
256 | desc="Running tokenizer on train dataset",
257 | )
258 | print_dataset_example(train_dataset[0])
259 |
260 | if training_args.do_eval:
261 | max_target_length = data_args.val_max_target_length
262 | if "validation" not in raw_datasets:
263 | raise ValueError("--do_eval requires a validation dataset")
264 | eval_dataset = raw_datasets["validation"]
265 | if data_args.max_eval_samples is not None:
266 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
267 | eval_dataset = eval_dataset.select(range(max_eval_samples))
268 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
269 | eval_dataset = eval_dataset.map(
270 | preprocess_function_eval,
271 | batched=True,
272 | num_proc=data_args.preprocessing_num_workers,
273 | remove_columns=column_names,
274 | load_from_cache_file=not data_args.overwrite_cache,
275 | desc="Running tokenizer on validation dataset",
276 | )
277 | print_dataset_example(eval_dataset[0])
278 |
279 | if training_args.do_predict:
280 | max_target_length = data_args.val_max_target_length
281 | if "test" not in raw_datasets:
282 | raise ValueError("--do_predict requires a test dataset")
283 | predict_dataset = raw_datasets["test"]
284 | if data_args.max_predict_samples is not None:
285 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
286 | predict_dataset = predict_dataset.select(range(max_predict_samples))
287 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
288 | predict_dataset = predict_dataset.map(
289 | preprocess_function_eval,
290 | batched=True,
291 | num_proc=data_args.preprocessing_num_workers,
292 | remove_columns=column_names,
293 | load_from_cache_file=not data_args.overwrite_cache,
294 | desc="Running tokenizer on prediction dataset",
295 | )
296 | print_dataset_example(predict_dataset[0])
297 |
298 | # Data collator
299 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
300 | data_collator = DataCollatorForSeq2Seq(
301 | tokenizer,
302 | model=model,
303 | label_pad_token_id=label_pad_token_id,
304 | pad_to_multiple_of=None,
305 | padding=False
306 | )
307 |
308 | # Metric
309 | def compute_metrics(eval_preds):
310 | preds, labels = eval_preds
311 | if isinstance(preds, tuple):
312 | preds = preds[0]
313 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
314 | if data_args.ignore_pad_token_for_loss:
315 | # Replace -100 in the labels as we can't decode them.
316 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
317 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
318 |
319 | score_dict = {
320 | "rouge-1": [],
321 | "rouge-2": [],
322 | "rouge-l": [],
323 | "bleu-4": []
324 | }
325 | for pred, label in zip(decoded_preds, decoded_labels):
326 | hypothesis = list(jieba.cut(pred))
327 | reference = list(jieba.cut(label))
328 | rouge = Rouge()
329 | scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
330 | result = scores[0]
331 |
332 | for k, v in result.items():
333 | score_dict[k].append(round(v["f"] * 100, 4))
334 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
335 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
336 |
337 | for k, v in score_dict.items():
338 | score_dict[k] = float(np.mean(v))
339 | return score_dict
340 |
341 | # Override the decoding parameters of Seq2SeqTrainer
342 | training_args.generation_max_length = (
343 | training_args.generation_max_length
344 | if training_args.generation_max_length is not None
345 | else data_args.val_max_target_length
346 | )
347 | training_args.generation_num_beams = (
348 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
349 | )
350 | # Initialize our Trainer
351 | trainer = Seq2SeqTrainer(
352 | model=model,
353 | args=training_args,
354 | train_dataset=train_dataset if training_args.do_train else None,
355 | eval_dataset=eval_dataset if training_args.do_eval else None,
356 | tokenizer=tokenizer,
357 | data_collator=data_collator,
358 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
359 | save_prefixencoder=model_args.pre_seq_len is not None
360 | )
361 |
362 | # Training
363 | if training_args.do_train:
364 | checkpoint = None
365 | if training_args.resume_from_checkpoint is not None:
366 | checkpoint = training_args.resume_from_checkpoint
367 | # elif last_checkpoint is not None:
368 | # checkpoint = last_checkpoint
369 | model.gradient_checkpointing_enable()
370 | model.enable_input_require_grads()
371 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
372 | # trainer.save_model() # Saves the tokenizer too for easy upload
373 |
374 | metrics = train_result.metrics
375 | max_train_samples = (
376 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
377 | )
378 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
379 |
380 | trainer.log_metrics("train", metrics)
381 | trainer.save_metrics("train", metrics)
382 | trainer.save_state()
383 |
384 | # Evaluation
385 | results = {}
386 | if training_args.do_eval:
387 | logger.info("*** Evaluate ***")
388 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95)
389 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
390 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
391 |
392 | trainer.log_metrics("eval", metrics)
393 | trainer.save_metrics("eval", metrics)
394 |
395 | if training_args.do_predict:
396 | logger.info("*** Predict ***")
397 |
398 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95)
399 | metrics = predict_results.metrics
400 | max_predict_samples = (
401 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
402 | )
403 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
404 |
405 | trainer.log_metrics("predict", metrics)
406 | trainer.save_metrics("predict", metrics)
407 |
408 | if trainer.is_world_process_zero():
409 | if training_args.predict_with_generate:
410 | predictions = tokenizer.batch_decode(
411 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
412 | )
413 | predictions = [pred.strip() for pred in predictions]
414 | labels = tokenizer.batch_decode(
415 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
416 | )
417 | labels = [label.strip() for label in labels]
418 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
419 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
420 | for p, l in zip(predictions, labels):
421 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
422 | writer.write(f"{res}\n")
423 | return results
424 |
425 |
426 | def _mp_fn(index):
427 | # For xla_spawn (TPUs)
428 | main()
429 |
430 |
431 | if __name__ == "__main__":
432 | main()
433 |
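compute_metrics above segments predictions and references with jieba, scores ROUGE F1 with rouge_chinese on the space-joined segments, and computes a smoothed character-level BLEU-4 with NLTK. A standalone sketch of the same computation on a single toy pair (the sentences are illustrative):

import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

pred = "青霉素过敏患者禁用本品"   # toy model output
label = "青霉素过敏者禁用本品"    # toy reference

hypothesis = list(jieba.cut(pred))
reference = list(jieba.cut(label))
rouge = Rouge().get_scores(' '.join(hypothesis), ' '.join(reference))[0]
bleu4 = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)

print({k: round(v["f"] * 100, 4) for k, v in rouge.items()})
print("bleu-4", round(bleu4 * 100, 4))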
--------------------------------------------------------------------------------
/src/ptuning/prediction.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | CHECKPOINT=0523-bio_prompt_1-chatglm-6b-pt-128-2e-2-bs8-accumulation2
3 | STEP=34900
4 |
5 | CUDA_VISIBLE_DEVICES=1 python3 main.py \
6 | --do_predict \
7 | --validation_file ../../data/val_prompt.json \
8 | --test_file ../../data/test_prompt.json \
9 | --overwrite_cache \
10 | --prompt_column prompt \
11 | --response_column completion \
12 | --model_name_or_path /home/shinian.ljl/projects/ChatGLM-6B/THUDM/chatglm-6b \
13 | --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
14 | --output_dir ./output/$CHECKPOINT \
15 | --overwrite_output_dir \
16 | --max_source_length 256 \
17 | --max_target_length 256 \
18 | --per_device_eval_batch_size 1 \
19 | --predict_with_generate \
20 | --pre_seq_len $PRE_SEQ_LEN
21 |
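Running this script makes main.py write generated_predictions.txt into the checkpoint's output directory: one JSON object per line with labels and predict keys. A minimal sketch for loading those predictions back for inspection (the path below is illustrative):

import json
import pandas as pd

rows = []
with open("./output/your-checkpoint-dir/generated_predictions.txt", encoding="utf-8") as f:  # illustrative path
    for line in f:
        rows.append(json.loads(line))

df = pd.DataFrame(rows)   # two columns: labels, predict
print(df.head())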
--------------------------------------------------------------------------------
/src/ptuning/train.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | LR=2e-2
3 |
4 | CUDA_VISIBLE_DEVICES=1 python3 main.py \
5 | --do_train \
6 | --train_file /home/shinian.ljl/data/bio/CMedQA/train_prompt_1.json \
7 | --validation_file /home/shinian.ljl/data/bio/CMedQA/val_prompt_1.json \
8 | --prompt_column prompt \
9 | --response_column completion \
10 | --overwrite_cache \
11 | --model_name_or_path /home/shinian.ljl/projects/ChatGLM-6B/THUDM/chatglm-6b \
12 | --output_dir output/0813-bio_prompt_1-chatglm-6b-pt-$PRE_SEQ_LEN-$LR-bs8-accumulation2 \
13 | --overwrite_output_dir \
14 | --max_source_length 256 \
15 | --max_target_length 256 \
16 | --per_device_train_batch_size 8 \
17 | --per_device_eval_batch_size 8 \
18 | --gradient_accumulation_steps 2 \
19 | --predict_with_generate \
20 | --max_steps 50000 \
21 | --logging_steps 10 \
22 | --save_steps 500 \
23 | --learning_rate $LR \
24 | --pre_seq_len $PRE_SEQ_LEN \
25 | --report_to wandb
--------------------------------------------------------------------------------
/src/ptuning/trainer_seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | from torch import nn
19 | from torch.utils.data import Dataset
20 |
21 | from transformers.deepspeed import is_deepspeed_zero3_enabled
22 | from trainer import Trainer
23 | from transformers.trainer_utils import PredictionOutput
24 | from transformers.utils import logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | class Seq2SeqTrainer(Trainer):
31 | def evaluate(
32 | self,
33 | eval_dataset: Optional[Dataset] = None,
34 | ignore_keys: Optional[List[str]] = None,
35 | metric_key_prefix: str = "eval",
36 | **gen_kwargs
37 | ) -> Dict[str, float]:
38 | """
39 | Run evaluation and returns metrics.
40 |
41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
42 | (pass it to the init `compute_metrics` argument).
43 |
44 | You can also subclass and override this method to inject custom behavior.
45 |
46 | Args:
47 | eval_dataset (`Dataset`, *optional*):
48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
50 | method.
51 | ignore_keys (`List[str]`, *optional*):
52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
53 | gathering predictions.
54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
56 | "eval_bleu" if the prefix is `"eval"` (default)
57 | max_length (`int`, *optional*):
58 | The maximum target length to use when predicting with the generate method.
59 | num_beams (`int`, *optional*):
60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
61 | beam search.
62 | gen_kwargs:
63 | Additional `generate` specific kwargs.
64 |
65 | Returns:
66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
67 | dictionary also contains the epoch number which comes from the training state.
68 | """
69 |
70 | gen_kwargs = gen_kwargs.copy()
71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
72 | gen_kwargs["max_length"] = self.args.generation_max_length
73 | gen_kwargs["num_beams"] = (
74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
75 | )
76 | self._gen_kwargs = gen_kwargs
77 |
78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
79 |
80 | def predict(
81 | self,
82 | test_dataset: Dataset,
83 | ignore_keys: Optional[List[str]] = None,
84 | metric_key_prefix: str = "test",
85 | **gen_kwargs
86 | ) -> PredictionOutput:
87 | """
88 | Run prediction and returns predictions and potential metrics.
89 |
90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
91 | will also return metrics, like in `evaluate()`.
92 |
93 | Args:
94 | test_dataset (`Dataset`):
95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
96 | `model.forward()` method are automatically removed. Has to implement the method `__len__`
97 | ignore_keys (`List[str]`, *optional*):
98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
99 | gathering predictions.
100 |             metric_key_prefix (`str`, *optional*, defaults to `"test"`):
101 |                 An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
102 |                 "test_bleu" if the prefix is `"test"` (default)
103 | max_length (`int`, *optional*):
104 | The maximum target length to use when predicting with the generate method.
105 | num_beams (`int`, *optional*):
106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
107 | beam search.
108 | gen_kwargs:
109 | Additional `generate` specific kwargs.
110 |
111 |
112 |
113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
114 | padding in a token classification task) the predictions will be padded (on the right) to allow for
115 | concatenation into one array. The padding index is -100.
116 |
117 |
118 |
119 | Returns: *NamedTuple* A namedtuple with the following keys:
120 |
121 | - predictions (`np.ndarray`): The predictions on `test_dataset`.
122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
124 | labels).
125 | """
126 |
127 | gen_kwargs = gen_kwargs.copy()
128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
129 | gen_kwargs["max_length"] = self.args.generation_max_length
130 | gen_kwargs["num_beams"] = (
131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
132 | )
133 | self._gen_kwargs = gen_kwargs
134 |
135 |
136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
137 |
138 | def prediction_step(
139 | self,
140 | model: nn.Module,
141 | inputs: Dict[str, Union[torch.Tensor, Any]],
142 | prediction_loss_only: bool,
143 | ignore_keys: Optional[List[str]] = None,
144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
145 | """
146 | Perform an evaluation step on `model` using `inputs`.
147 |
148 | Subclass and override to inject custom behavior.
149 |
150 | Args:
151 | model (`nn.Module`):
152 | The model to evaluate.
153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
154 | The inputs and targets of the model.
155 |
156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
157 | argument `labels`. Check your model's documentation for all accepted arguments.
158 | prediction_loss_only (`bool`):
159 | Whether or not to return the loss only.
160 |
161 | Return:
162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
163 | labels (each being optional).
164 | """
165 |
166 | if not self.args.predict_with_generate or prediction_loss_only:
167 | return super().prediction_step(
168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
169 | )
170 |
171 | has_labels = "labels" in inputs
172 | inputs = self._prepare_inputs(inputs)
173 |
174 | # XXX: adapt synced_gpus for fairscale as well
175 | gen_kwargs = self._gen_kwargs.copy()
176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
177 | gen_kwargs["max_length"] = self.model.config.max_length
178 | gen_kwargs["num_beams"] = (
179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
180 | )
181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
182 | gen_kwargs["synced_gpus"] = (
183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
184 | )
185 |
186 | if "attention_mask" in inputs:
187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
188 | if "position_ids" in inputs:
189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None)
190 | if "global_attention_mask" in inputs:
191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
192 |
193 | # prepare generation inputs
194 | # some encoder-decoder models can have varying encoder's and thus
195 | # varying model input names
196 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
197 | generation_inputs = inputs[self.model.encoder.main_input_name]
198 | else:
199 | generation_inputs = inputs[self.model.main_input_name]
200 |
201 | gen_kwargs["input_ids"] = generation_inputs
202 | gen_kwargs["num_return_sequences"] = gen_kwargs["num_beams"]
203 | generated_tokens = self.model.generate(**gen_kwargs)
204 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
205 |
206 | # in case the batch is shorter than max length, the output should be padded
207 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
208 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
209 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
210 | gen_kwargs["max_new_tokens"] + 1
211 | ):
212 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
213 |
214 | loss = None
215 |
216 | if self.args.prediction_loss_only:
217 | return (loss, None, None)
218 |
219 | if has_labels:
220 | labels = inputs["labels"]
221 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
222 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
223 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
224 | gen_kwargs["max_new_tokens"] + 1
225 | ):
226 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
227 | else:
228 | labels = None
229 |
230 | return (loss, generated_tokens, labels)
231 |
232 | def _pad_tensors_to_max_len(self, tensor, max_length):
233 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
234 | # If PAD token is not defined at least EOS token has to be defined
235 | pad_token_id = (
236 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
237 | )
238 | else:
239 | if self.model.config.pad_token_id is not None:
240 | pad_token_id = self.model.config.pad_token_id
241 | else:
242 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
243 |
244 | padded_tensor = pad_token_id * torch.ones(
245 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
246 | )
247 | padded_tensor[:, : tensor.shape[-1]] = tensor
248 | return padded_tensor
249 |
--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | aiofiles==23.1.0
3 | aiohttp==3.8.4
4 | aiosignal==1.3.1
5 | altair==4.2.2
6 | anyio==3.6.2
7 | appdirs==1.4.4
8 | async-timeout==4.0.2
9 | attrs==23.1.0
10 | certifi==2022.12.7
11 | charset-normalizer==3.1.0
12 | click==8.1.3
13 | cmake==3.26.3
14 | contourpy==1.0.7
15 | cpm-kernels==1.0.11
16 | cycler==0.11.0
17 | datasets==2.12.0
18 | dill==0.3.6
19 | docker-pycreds==0.4.0
20 | entrypoints==0.4
21 | fastapi==0.95.1
22 | ffmpy==0.3.0
23 | filelock==3.12.0
24 | fonttools==4.39.3
25 | frozenlist==1.3.3
26 | fsspec==2023.4.0
27 | gitdb==4.0.10
28 | GitPython==3.1.31
29 | gradio==3.27.0
30 | gradio_client==0.1.3
31 | h11==0.14.0
32 | httpcore==0.17.0
33 | httpx==0.24.0
34 | huggingface-hub==0.13.4
35 | idna==3.4
36 | importlib-metadata==6.6.0
37 | importlib-resources==5.12.0
38 | jieba==0.42.1
39 | Jinja2==3.1.2
40 | joblib==1.2.0
41 | jsonschema==4.17.3
42 | kiwisolver==1.4.4
43 | latex2mathml==3.75.2
44 | linkify-it-py==2.0.0
45 | lit==16.0.1
46 | Markdown==3.4.3
47 | markdown-it-py==2.2.0
48 | MarkupSafe==2.1.2
49 | matplotlib==3.7.1
50 | mdit-py-plugins==0.3.3
51 | mdtex2html==1.2.0
52 | mdurl==0.1.2
53 | mpmath==1.3.0
54 | multidict==6.0.4
55 | multiprocess==0.70.14
56 | networkx==3.1
57 | nltk==3.8.1
58 | numpy==1.24.3
59 | nvidia-cublas-cu11==11.10.3.66
60 | nvidia-cuda-cupti-cu11==11.7.101
61 | nvidia-cuda-nvrtc-cu11==11.7.99
62 | nvidia-cuda-runtime-cu11==11.7.99
63 | nvidia-cudnn-cu11==8.5.0.96
64 | nvidia-cufft-cu11==10.9.0.58
65 | nvidia-curand-cu11==10.2.10.91
66 | nvidia-cusolver-cu11==11.4.0.1
67 | nvidia-cusparse-cu11==11.7.4.91
68 | nvidia-nccl-cu11==2.14.3
69 | nvidia-nvtx-cu11==11.7.91
70 | orjson==3.8.10
71 | packaging==23.1
72 | pandas==2.0.0
73 | pathtools==0.1.2
74 | peft==0.4.0
75 | Pillow==9.5.0
76 | protobuf==4.22.3
77 | psutil==5.9.5
78 | pyarrow==12.0.0
79 | pydantic==1.10.7
80 | pydub==0.25.1
81 | pyparsing==3.0.9
82 | pyrsistent==0.19.3
83 | python-dateutil==2.8.2
84 | python-multipart==0.0.6
85 | pytz==2023.3
86 | PyYAML==6.0
87 | regex==2023.3.23
88 | requests==2.28.2
89 | responses==0.18.0
90 | rouge-chinese==1.0.3
91 | safetensors==0.3.1
92 | semantic-version==2.10.0
93 | sentencepiece==0.1.98
94 | sentry-sdk==1.23.0
95 | setproctitle==1.3.2
96 | six==1.16.0
97 | smmap==5.0.0
98 | sniffio==1.3.0
99 | starlette==0.26.1
100 | sympy==1.11.1
101 | tokenizers==0.13.3
102 | toolz==0.12.0
103 | torch==2.0.0
104 | tqdm==4.65.0
105 | transformers==4.27.1
106 | triton==2.0.0
107 | typing_extensions==4.5.0
108 | tzdata==2023.3
109 | uc-micro-py==1.0.1
110 | urllib3==1.26.15
111 | uvicorn==0.21.1
112 | wandb==0.15.2
113 | websockets==11.0.2
114 | xxhash==3.2.0
115 | yarl==1.9.1
116 | zipp==3.15.0
117 |
--------------------------------------------------------------------------------