├── .gitignore ├── README.md ├── README_EN.md ├── datasets └── PromptCBLUE │ └── toy_examples │ ├── dev.json │ ├── dev_structured.json │ ├── results.json │ ├── test.json │ ├── test_predictions.json │ └── train.json ├── peft ├── __init__.py ├── import_utils.py ├── mapping.py ├── peft_model.py ├── tuners │ ├── __init__.py │ ├── adalora.py │ ├── adaption_prompt.py │ ├── lora.py │ ├── p_tuning.py │ ├── prefix_tuning.py │ └── prompt_tuning.py └── utils │ ├── __init__.py │ ├── config.py │ ├── other.py │ └── save_and_load.py ├── pics ├── dingding_groups.jpg ├── promptCBLUE_banner_v0.png ├── promptCBLUE_en_banner_v0.png └── wechat_qrcode.jpg ├── requirements.txt └── src ├── README.md ├── data ├── CBLUE任务改造说明与举例.md ├── templates.json ├── templates_augment.json └── 结构化预测结果格式说明.md ├── download_checkpoints.py ├── evaluation ├── README.txt ├── evaluate.py ├── evaluators.py ├── input_param.json ├── post_generate_process.py └── text2dt_eval_func.py ├── ft_chatglm_lora ├── arguments.py ├── evaluate.sh ├── main.py ├── train.sh ├── trainer.py └── trainer_seq2seq.py ├── ft_chatglm_ptuning ├── arguments.py ├── config.json ├── configuration_chatglm.py ├── evaluate.sh ├── main.py ├── modeling_chatglm.py ├── quantization.py ├── test_modeling_chatglm.py ├── tokenization_chatglm.py ├── tokenizer_config.json ├── train.sh ├── trainer.py └── trainer_seq2seq.py └── ft_llama_lora ├── merge_llama_with_chinese_lora.py ├── run_clm_pt_with_peft.py ├── run_train.sh └── vllm_serving ├── launch_vllm.py ├── llm_engine.py ├── merge_llama_with_lora.py ├── utils.py └── web_service_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .idea/* 3 | 4 | __pycache__/* 5 | 6 | catboost_info/* 7 | 8 | references/* 9 | results/* 10 | internal/* 11 | 12 | experiments/* 13 | assets/* 14 | tmps/* 15 | 16 | wandb/* 17 | 18 | tmp/* 19 | datasets/* 20 | datasets/PromptCBLUE/test_a_open/* 21 | src/eval_chatgpt/* 22 | 23 | *.log 24 | 25 | *.wav 26 | *.mp3 27 | 28 | *.pkl 29 | 30 | 31 | 32 | *.png 33 | 34 | *.jpg 35 | *.xlsx 36 | *.csv 37 | 38 | catboost_info 39 | 40 | 41 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | [**中文**](./README.md) | [**English**](./README_EN.md) 2 | 3 |

[Banner image and GitHub badges: "GitHub", "GitHub top language"]
12 | 13 | 14 | The emergence of large language models (LLMs) such as ChatGPT and GPT-4 has sparked a new wave of research in natural language processing: these models display capabilities reminiscent of artificial general intelligence (AGI) and have attracted wide attention from industry. With the prevalence of LLMs, almost all NLP tasks have been reformulated as prompt-based language generation tasks. However, the Chinese medical NLP community still lacks a unified, task-based evaluation benchmark. 15 | 16 | 17 | To promote the development and application of LLMs in the medical field, Professor Xiaoling Wang's team at East China Normal University, in collaboration with the Alibaba Tianchi platform, Fudan University, Huashan Hospital affiliated with Fudan University, Northeastern University, Harbin Institute of Technology (Shenzhen), Peng Cheng Laboratory, and Tongji University, has launched the **PromptCBLUE** evaluation benchmark by adapting the [CBLUE](https://tianchi.aliyun.com/dataset/95414) benchmark. PromptCBLUE recasts all 16 different medical NLP tasks as prompt-based language generation tasks, making it the first evaluation benchmark for Chinese medical LLMs. PromptCBLUE is one of the evaluation tasks of the [CCKS-2023](https://sigkg.cn/ccks2023/evaluation) conference and is open for evaluation on the Alibaba Tianchi competition platform. Industrial practitioners, students, and researchers are welcome to register and participate. 18 | 19 | Because LLM training may involve commercial data, and various external factors may prevent a model from being open-sourced, the PromptCBLUE evaluation offers two tracks: 20 | - General track: This track accepts submissions from enterprises, universities, open-source communities, research teams, or individuals who have developed their own LLMs; participants are not required to open-source their models. The evaluation website for this track is [CCKS2023-PromptCBLUE General Track](https://tianchi.aliyun.com/competition/entrance/532085/introduction). 21 | - Open-source track: This track is open to all participating teams. Participants must use an open-source LLM framework and train/fine-tune only on open-source datasets or on datasets that can be submitted to the organizers for review. The evaluation website for this track is [CCKS2023-PromptCBLUE Open-source Track](https://tianchi.aliyun.com/competition/entrance/532084/introduction). 22 | 23 | 24 | To help improve the abilities of LLMs in the medical field, we are also open-sourcing the following data and model resources: 25 | - 🚀 [ChatMed_Consult_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset): a Chinese medical online-consultation dataset containing 500k+ consultation queries with responses generated by ChatGPT. 26 | - 🚀 [ChatMed-Consult model](https://huggingface.co/michaelwzhu/ChatMed-Consult): a large Chinese medical consultation model fine-tuned on [ChatMed_Consult_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset). The model is based on [LLaMA-7b](https://github.com/facebookresearch/llama) merged with the LoRA weights of [Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca). All data and code are publicly available in the [ChatMed project](https://github.com/michael-wzhu/ChatMed). 
27 | - ⏳ [ChatMed_TCM_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_TCM_Dataset). A dataset of instructions for traditional Chinese medicine (TCM) with 26k+ samples generated using the [entity-centric self-instruct method](https://github.com/michael-wzhu/ChatMed/blob/main/src/) based on our open-sourced [TCM knowledge graph](https://github.com/ywjawmw/TCM_KG) and ChatGPT responses. 28 | - ⏳ [ChatMed-TCM model](https://huggingface.co/michaelwzhu/ChatMed-TCM) : Empowering TCM inheritance with LLMs. This model is also based on LlaMA and fine-tuned on the [ChatMed_TCM_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_TCM_Dataset) with LoRA. 29 | 30 | ---- 31 | 32 | [Text2DT](https://github.com/michael-wzhu/Text2DT_Baseline) | [ChatMed_Consult_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset) | [ChatMed-Consult model](https://huggingface.co/michaelwzhu/ChatMed-Consult) | [ChatMed_TCM_Dataset](https://huggingface.co/datasets/michaelwzhu/ChatMed_TCM_Dataset) | [ChatMed-TCM model](https://huggingface.co/michaelwzhu/ChatMed-TCM) 33 | 34 | 35 | ## Updates 36 | 37 | 2023/05/12 English README is out! -------------------------------------------------------------------------------- /datasets/PromptCBLUE/toy_examples/results.json: -------------------------------------------------------------------------------- 1 | { 2 | "CMeEE-V2": [ 3 | { 4 | "sample_id": "test-13436", 5 | "answer": [] 6 | }, 7 | { 8 | "sample_id": "test-21108", 9 | "answer": [] 10 | }, 11 | { 12 | "sample_id": "test-29977", 13 | "answer": [] 14 | }, 15 | { 16 | "sample_id": "test-2328", 17 | "answer": [] 18 | }, 19 | { 20 | "sample_id": "test-1202", 21 | "answer": [] 22 | } 23 | ], 24 | "CMeIE": [ 25 | { 26 | "sample_id": "test-27030", 27 | "answer": [] 28 | }, 29 | { 30 | "sample_id": "test-2375", 31 | "answer": [] 32 | }, 33 | { 34 | "sample_id": "test-13853", 35 | "answer": [] 36 | }, 37 | { 38 | "sample_id": "test-21241", 39 | "answer": [] 40 | }, 41 | { 42 | "sample_id": "test-10386", 43 | "answer": [] 44 | } 45 | ], 46 | "CHIP-CDN": [ 47 | { 48 | "sample_id": "test-12089", 49 | "answer": [] 50 | }, 51 | { 52 | "sample_id": "test-82943", 53 | "answer": [] 54 | }, 55 | { 56 | "sample_id": "test-75472", 57 | "answer": [] 58 | }, 59 | { 60 | "sample_id": "test-123121", 61 | "answer": [] 62 | }, 63 | { 64 | "sample_id": "test-85974", 65 | "answer": [] 66 | } 67 | ], 68 | "CHIP-CDEE": [ 69 | { 70 | "sample_id": "test-2598", 71 | "answer": [] 72 | }, 73 | { 74 | "sample_id": "test-3240", 75 | "answer": [] 76 | }, 77 | { 78 | "sample_id": "test-4341", 79 | "answer": [] 80 | }, 81 | { 82 | "sample_id": "test-6453", 83 | "answer": [] 84 | }, 85 | { 86 | "sample_id": "test-1666", 87 | "answer": [] 88 | } 89 | ], 90 | "CHIP-STS": [ 91 | { 92 | "sample_id": "test-12693", 93 | "answer": "" 94 | }, 95 | { 96 | "sample_id": "test-11036", 97 | "answer": "" 98 | }, 99 | { 100 | "sample_id": "test-56872", 101 | "answer": "" 102 | }, 103 | { 104 | "sample_id": "test-5225", 105 | "answer": "" 106 | }, 107 | { 108 | "sample_id": "test-13728", 109 | "answer": "" 110 | } 111 | ], 112 | "CHIP-CTC": [ 113 | { 114 | "sample_id": "test-43923", 115 | "answer": "" 116 | }, 117 | { 118 | "sample_id": "test-20172", 119 | "answer": "" 120 | }, 121 | { 122 | "sample_id": "test-43291", 123 | "answer": "" 124 | }, 125 | { 126 | "sample_id": "test-43152", 127 | "answer": "" 128 | }, 129 | { 130 | "sample_id": "test-17740", 131 | "answer": "" 132 | } 133 | ], 134 | "CHIP-MDCFNPC": [ 135 | { 136 | "sample_id": 
"test-125634", 137 | "answer": [] 138 | }, 139 | { 140 | "sample_id": "test-114430", 141 | "answer": [] 142 | }, 143 | { 144 | "sample_id": "test-230623", 145 | "answer": [] 146 | }, 147 | { 148 | "sample_id": "test-183330", 149 | "answer": [] 150 | }, 151 | { 152 | "sample_id": "test-33282", 153 | "answer": [] 154 | } 155 | ], 156 | "KUAKE-IR": [ 157 | { 158 | "sample_id": "test-5149", 159 | "answer": "" 160 | }, 161 | { 162 | "sample_id": "test-11078", 163 | "answer": "" 164 | }, 165 | { 166 | "sample_id": "test-2952", 167 | "answer": "" 168 | }, 169 | { 170 | "sample_id": "test-16745", 171 | "answer": "" 172 | }, 173 | { 174 | "sample_id": "test-47", 175 | "answer": "" 176 | } 177 | ], 178 | "KUAKE-QIC": [ 179 | { 180 | "sample_id": "test-9060", 181 | "answer": "" 182 | }, 183 | { 184 | "sample_id": "test-7590", 185 | "answer": "" 186 | }, 187 | { 188 | "sample_id": "test-2854", 189 | "answer": "" 190 | }, 191 | { 192 | "sample_id": "test-5419", 193 | "answer": "" 194 | }, 195 | { 196 | "sample_id": "test-2560", 197 | "answer": "" 198 | } 199 | ], 200 | "KUAKE-QQR": [ 201 | { 202 | "sample_id": "test-6143", 203 | "answer": "" 204 | }, 205 | { 206 | "sample_id": "test-6781", 207 | "answer": "" 208 | }, 209 | { 210 | "sample_id": "test-4876", 211 | "answer": "" 212 | }, 213 | { 214 | "sample_id": "test-7980", 215 | "answer": "" 216 | }, 217 | { 218 | "sample_id": "test-10087", 219 | "answer": "" 220 | } 221 | ], 222 | "KUAKE-QTR": [ 223 | { 224 | "sample_id": "test-17079", 225 | "answer": "" 226 | }, 227 | { 228 | "sample_id": "test-33477", 229 | "answer": "" 230 | }, 231 | { 232 | "sample_id": "test-6747", 233 | "answer": "" 234 | }, 235 | { 236 | "sample_id": "test-22159", 237 | "answer": "" 238 | }, 239 | { 240 | "sample_id": "test-5848", 241 | "answer": "" 242 | } 243 | ], 244 | "MedDG": [ 245 | { 246 | "sample_id": "test-75967", 247 | "answer": "" 248 | }, 249 | { 250 | "sample_id": "test-129", 251 | "answer": "" 252 | }, 253 | { 254 | "sample_id": "test-85396", 255 | "answer": "" 256 | }, 257 | { 258 | "sample_id": "test-44400", 259 | "answer": "" 260 | }, 261 | { 262 | "sample_id": "test-61086", 263 | "answer": "" 264 | } 265 | ], 266 | "IMCS-V2-MRG": [ 267 | { 268 | "sample_id": "test-3984", 269 | "answer": {} 270 | }, 271 | { 272 | "sample_id": "test-2303", 273 | "answer": {} 274 | }, 275 | { 276 | "sample_id": "test-1442", 277 | "answer": {} 278 | }, 279 | { 280 | "sample_id": "test-3975", 281 | "answer": {} 282 | }, 283 | { 284 | "sample_id": "test-5417", 285 | "answer": {} 286 | } 287 | ], 288 | "IMCS-V2-NER": [ 289 | { 290 | "sample_id": "test-30182", 291 | "answer": [] 292 | }, 293 | { 294 | "sample_id": "test-79788", 295 | "answer": [] 296 | }, 297 | { 298 | "sample_id": "test-32048", 299 | "answer": [] 300 | }, 301 | { 302 | "sample_id": "test-37196", 303 | "answer": [] 304 | }, 305 | { 306 | "sample_id": "test-30619", 307 | "answer": [] 308 | } 309 | ], 310 | "IMCS-V2-DAC": [ 311 | { 312 | "sample_id": "test-130825", 313 | "answer": "" 314 | }, 315 | { 316 | "sample_id": "test-13614", 317 | "answer": "" 318 | }, 319 | { 320 | "sample_id": "test-97745", 321 | "answer": "" 322 | }, 323 | { 324 | "sample_id": "test-152831", 325 | "answer": "" 326 | }, 327 | { 328 | "sample_id": "test-89272", 329 | "answer": "" 330 | } 331 | ], 332 | "IMCS-V2-SR": [ 333 | { 334 | "sample_id": "test-16623", 335 | "answer": [] 336 | }, 337 | { 338 | "sample_id": "test-27789", 339 | "answer": [] 340 | }, 341 | { 342 | "sample_id": "test-38502", 343 | "answer": [] 344 | }, 345 | { 346 | 
"sample_id": "test-3890", 347 | "answer": [] 348 | }, 349 | { 350 | "sample_id": "test-17282", 351 | "answer": [] 352 | } 353 | ] 354 | } -------------------------------------------------------------------------------- /peft/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | __version__ = "0.3.0.dev0" 21 | 22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model 23 | from .peft_model import ( 24 | PeftModel, 25 | PeftModelForCausalLM, 26 | PeftModelForSeq2SeqLM, 27 | PeftModelForSequenceClassification, 28 | PeftModelForTokenClassification, 29 | ) 30 | from .tuners import ( 31 | AdaptionPromptConfig, 32 | AdaptionPromptModel, 33 | LoraConfig, 34 | LoraModel, 35 | AdaLoraConfig, 36 | AdaLoraModel, 37 | PrefixEncoder, 38 | PrefixTuningConfig, 39 | PromptEmbedding, 40 | PromptEncoder, 41 | PromptEncoderConfig, 42 | PromptEncoderReparameterizationType, 43 | PromptTuningConfig, 44 | PromptTuningInit, 45 | ) 46 | from .utils import ( 47 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 48 | PeftConfig, 49 | PeftType, 50 | PromptLearningConfig, 51 | TaskType, 52 | bloom_model_postprocess_past_key_value, 53 | get_peft_model_state_dict, 54 | prepare_model_for_int8_training, 55 | set_peft_model_state_dict, 56 | shift_tokens_right, 57 | ) 58 | -------------------------------------------------------------------------------- /peft/import_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import importlib 16 | 17 | 18 | def is_bnb_available(): 19 | return importlib.util.find_spec("bitsandbytes") is not None 20 | -------------------------------------------------------------------------------- /peft/mapping.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .peft_model import ( 17 | PeftModel, 18 | PeftModelForCausalLM, 19 | PeftModelForSeq2SeqLM, 20 | PeftModelForSequenceClassification, 21 | PeftModelForTokenClassification, 22 | ) 23 | from .tuners import ( 24 | AdaLoraConfig, 25 | AdaptionPromptConfig, 26 | LoraConfig, 27 | PrefixTuningConfig, 28 | PromptEncoderConfig, 29 | PromptTuningConfig, 30 | ) 31 | from .utils import PromptLearningConfig 32 | 33 | 34 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { 35 | "SEQ_CLS": PeftModelForSequenceClassification, 36 | "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, 37 | "CAUSAL_LM": PeftModelForCausalLM, 38 | "TOKEN_CLS": PeftModelForTokenClassification, 39 | } 40 | 41 | PEFT_TYPE_TO_CONFIG_MAPPING = { 42 | "ADAPTION_PROMPT": AdaptionPromptConfig, 43 | "PROMPT_TUNING": PromptTuningConfig, 44 | "PREFIX_TUNING": PrefixTuningConfig, 45 | "P_TUNING": PromptEncoderConfig, 46 | "LORA": LoraConfig, 47 | "ADALORA": AdaLoraConfig, 48 | } 49 | 50 | 51 | def get_peft_config(config_dict): 52 | """ 53 | Returns a Peft config object from a dictionary. 54 | 55 | Args: 56 | config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters. 57 | """ 58 | 59 | return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) 60 | 61 | 62 | def _prepare_prompt_learning_config(peft_config, model_config): 63 | if peft_config.num_layers is None: 64 | if "num_hidden_layers" in model_config: 65 | num_layers = model_config["num_hidden_layers"] 66 | elif "num_layers" in model_config: 67 | num_layers = model_config["num_layers"] 68 | elif "n_layer" in model_config: 69 | num_layers = model_config["n_layer"] 70 | else: 71 | raise ValueError("Please specify `num_layers` in `peft_config`") 72 | peft_config.num_layers = num_layers 73 | 74 | if peft_config.token_dim is None: 75 | if "hidden_size" in model_config: 76 | token_dim = model_config["hidden_size"] 77 | elif "n_embd" in model_config: 78 | token_dim = model_config["n_embd"] 79 | elif "d_model" in model_config: 80 | token_dim = model_config["d_model"] 81 | else: 82 | raise ValueError("Please specify `token_dim` in `peft_config`") 83 | peft_config.token_dim = token_dim 84 | 85 | if peft_config.num_attention_heads is None: 86 | if "num_attention_heads" in model_config: 87 | num_attention_heads = model_config["num_attention_heads"] 88 | elif "n_head" in model_config: 89 | num_attention_heads = model_config["n_head"] 90 | elif "num_heads" in model_config: 91 | num_attention_heads = model_config["num_heads"] 92 | elif "encoder_attention_heads" in model_config: 93 | num_attention_heads = model_config["encoder_attention_heads"] 94 | else: 95 | raise ValueError("Please specify `num_attention_heads` in `peft_config`") 96 | peft_config.num_attention_heads = num_attention_heads 97 | 98 | if getattr(peft_config, "encoder_hidden_size", None) is None: 99 | setattr(peft_config, "encoder_hidden_size", token_dim) 100 | 101 | return peft_config 102 | 103 | 104 | def 
get_peft_model(model, peft_config): 105 | """ 106 | Returns a Peft model object from a model and a config. 107 | 108 | Args: 109 | model ([`transformers.PreTrainedModel`]): Model to be wrapped. 110 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. 111 | """ 112 | model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config 113 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) 114 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( 115 | peft_config, PromptLearningConfig 116 | ): 117 | return PeftModel(model, peft_config) 118 | if isinstance(peft_config, PromptLearningConfig): 119 | peft_config = _prepare_prompt_learning_config(peft_config, model_config) 120 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config) 121 | -------------------------------------------------------------------------------- /peft/tuners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel 21 | from .lora import LoraConfig, LoraModel 22 | from .adalora import AdaLoraConfig, AdaLoraModel 23 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 24 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig 25 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit 26 | -------------------------------------------------------------------------------- /peft/tuners/p_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
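
The `get_peft_config` and `get_peft_model` helpers defined in `peft/mapping.py` above are the entry points for wrapping a base model with a PEFT adapter. A minimal usage sketch follows; the base model name and the LoRA hyperparameters are illustrative assumptions rather than values prescribed by this repository.

```python
# Illustrative sketch only: "bert-base-uncased" and the LoRA hyperparameters are assumptions.
from transformers import AutoModelForSequenceClassification

from peft import get_peft_config, get_peft_model

config_dict = {
    "peft_type": "LORA",     # dispatched to LoraConfig via PEFT_TYPE_TO_CONFIG_MAPPING
    "task_type": "SEQ_CLS",  # dispatched to PeftModelForSequenceClassification
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.1,
}
peft_config = get_peft_config(config_dict)

base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = get_peft_model(base_model, peft_config)  # only the LoRA parameters remain trainable
```
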
15 | 16 | import enum 17 | import warnings 18 | from dataclasses import dataclass, field 19 | from typing import Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptEncoderReparameterizationType(str, enum.Enum): 27 | MLP = "MLP" 28 | LSTM = "LSTM" 29 | 30 | 31 | @dataclass 32 | class PromptEncoderConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEncoder`]. 35 | 36 | Args: 37 | encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]): 38 | The type of reparameterization to use. 39 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 40 | encoder_num_layers (`int`): The number of layers of the prompt encoder. 41 | encoder_dropout (`float`): The dropout probability of the prompt encoder. 42 | """ 43 | 44 | encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field( 45 | default=PromptEncoderReparameterizationType.MLP, 46 | metadata={"help": "How to reparameterize the prompt encoder"}, 47 | ) 48 | encoder_hidden_size: int = field( 49 | default=None, 50 | metadata={"help": "The hidden size of the prompt encoder"}, 51 | ) 52 | encoder_num_layers: int = field( 53 | default=2, 54 | metadata={"help": "The number of layers of the prompt encoder"}, 55 | ) 56 | encoder_dropout: float = field( 57 | default=0.0, 58 | metadata={"help": "The dropout of the prompt encoder"}, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.P_TUNING 63 | 64 | 65 | # Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py 66 | # with some refactor 67 | class PromptEncoder(torch.nn.Module): 68 | """ 69 | The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. 70 | 71 | Args: 72 | config ([`PromptEncoderConfig`]): The configuration of the prompt encoder. 73 | 74 | Example: 75 | 76 | ```py 77 | >>> from peft import PromptEncoder, PromptEncoderConfig 78 | 79 | >>> config = PromptEncoderConfig( 80 | ... peft_type="P_TUNING", 81 | ... task_type="SEQ_2_SEQ_LM", 82 | ... num_virtual_tokens=20, 83 | ... token_dim=768, 84 | ... num_transformer_submodules=1, 85 | ... num_attention_heads=12, 86 | ... num_layers=12, 87 | ... encoder_reparameterization_type="MLP", 88 | ... encoder_hidden_size=768, 89 | ... ) 90 | 91 | >>> prompt_encoder = PromptEncoder(config) 92 | ``` 93 | 94 | **Attributes**: 95 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt encoder. 96 | - **mlp_head** (`torch.nn.Sequential`) -- The MLP head of the prompt encoder if `inference_mode=False`. 97 | - **lstm_head** (`torch.nn.LSTM`) -- The LSTM head of the prompt encoder if `inference_mode=False` and 98 | `encoder_reparameterization_type="LSTM"`. 99 | - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model. 100 | - **input_size** (`int`) -- The input size of the prompt encoder. 101 | - **output_size** (`int`) -- The output size of the prompt encoder. 102 | - **hidden_size** (`int`) -- The hidden size of the prompt encoder. 103 | - **total_virtual_tokens** (`int`): The total number of virtual tokens of the 104 | prompt encoder. 105 | - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]): The encoder type of the prompt 106 | encoder. 
107 | 108 | 109 | Input shape: (`batch_size`, `total_virtual_tokens`) 110 | 111 | Output shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 112 | """ 113 | 114 | def __init__(self, config): 115 | super().__init__() 116 | self.token_dim = config.token_dim 117 | self.input_size = self.token_dim 118 | self.output_size = self.token_dim 119 | self.hidden_size = config.encoder_hidden_size 120 | self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 121 | self.encoder_type = config.encoder_reparameterization_type 122 | 123 | # embedding 124 | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim) 125 | if not config.inference_mode: 126 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 127 | lstm_dropout = config.encoder_dropout 128 | num_layers = config.encoder_num_layers 129 | # LSTM 130 | self.lstm_head = torch.nn.LSTM( 131 | input_size=self.input_size, 132 | hidden_size=self.hidden_size, 133 | num_layers=num_layers, 134 | dropout=lstm_dropout, 135 | bidirectional=True, 136 | batch_first=True, 137 | ) 138 | 139 | self.mlp_head = torch.nn.Sequential( 140 | torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2), 141 | torch.nn.ReLU(), 142 | torch.nn.Linear(self.hidden_size * 2, self.output_size), 143 | ) 144 | 145 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 146 | warnings.warn( 147 | f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." 148 | ) 149 | layers = [ 150 | torch.nn.Linear(self.input_size, self.hidden_size), 151 | torch.nn.ReLU(), 152 | torch.nn.Linear(self.hidden_size, self.hidden_size), 153 | torch.nn.ReLU(), 154 | torch.nn.Linear(self.hidden_size, self.output_size), 155 | ] 156 | self.mlp_head = torch.nn.Sequential(*layers) 157 | 158 | else: 159 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 160 | 161 | def forward(self, indices): 162 | input_embeds = self.embedding(indices) 163 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 164 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]) 165 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 166 | output_embeds = self.mlp_head(input_embeds) 167 | else: 168 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 169 | 170 | return output_embeds 171 | -------------------------------------------------------------------------------- /peft/tuners/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
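
For the `PromptEncoder` defined in `peft/tuners/p_tuning.py` above, a small self-contained shape check may help make the documented input/output shapes concrete; all hyperparameter values below are illustrative.

```python
# Shape check for PromptEncoder; hyperparameter values are illustrative only.
import torch

from peft import PromptEncoder, PromptEncoderConfig

config = PromptEncoderConfig(
    task_type="SEQ_2_SEQ_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_transformer_submodules=1,
    num_attention_heads=12,
    num_layers=12,
    encoder_reparameterization_type="MLP",
    encoder_hidden_size=768,
)
encoder = PromptEncoder(config)

# One batch of virtual-token indices: shape (batch_size, total_virtual_tokens).
indices = torch.arange(config.num_virtual_tokens).unsqueeze(0)
print(encoder(indices).shape)  # torch.Size([1, 20, 768]) == (batch_size, total_virtual_tokens, token_dim)
```
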
15 | 16 | 17 | from dataclasses import dataclass, field 18 | 19 | import torch 20 | 21 | from ..utils import PeftType, PromptLearningConfig 22 | 23 | 24 | @dataclass 25 | class PrefixTuningConfig(PromptLearningConfig): 26 | """ 27 | This is the configuration class to store the configuration of a [`PrefixEncoder`]. 28 | 29 | Args: 30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 31 | prefix_projection (`bool`): Whether to project the prefix embeddings. 32 | """ 33 | 34 | encoder_hidden_size: int = field( 35 | default=None, 36 | metadata={"help": "The hidden size of the encoder"}, 37 | ) 38 | prefix_projection: bool = field( 39 | default=False, 40 | metadata={"help": "Whether to project the prefix tokens"}, 41 | ) 42 | 43 | def __post_init__(self): 44 | self.peft_type = PeftType.PREFIX_TUNING 45 | 46 | 47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py 48 | # with some refactor 49 | class PrefixEncoder(torch.nn.Module): 50 | r""" 51 | The `torch.nn` model to encode the prefix. 52 | 53 | Args: 54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. 55 | 56 | Example: 57 | 58 | ```py 59 | >>> from peft import PrefixEncoder, PrefixTuningConfig 60 | 61 | >>> config = PrefixTuningConfig( 62 | ... peft_type="PREFIX_TUNING", 63 | ... task_type="SEQ_2_SEQ_LM", 64 | ... num_virtual_tokens=20, 65 | ... token_dim=768, 66 | ... num_transformer_submodules=1, 67 | ... num_attention_heads=12, 68 | ... num_layers=12, 69 | ... encoder_hidden_size=768, 70 | ... ) 71 | >>> prefix_encoder = PrefixEncoder(config) 72 | ``` 73 | 74 | **Attributes**: 75 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder. 76 | - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if 77 | `prefix_projection` is `True`. 78 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. 79 | 80 | Input shape: (`batch_size`, `num_virtual_tokens`) 81 | 82 | Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`) 83 | """ 84 | 85 | def __init__(self, config): 86 | super().__init__() 87 | self.prefix_projection = config.prefix_projection 88 | token_dim = config.token_dim 89 | num_layers = config.num_layers 90 | encoder_hidden_size = config.encoder_hidden_size 91 | num_virtual_tokens = config.num_virtual_tokens 92 | if self.prefix_projection and not config.inference_mode: 93 | # Use a two-layer MLP to encode the prefix 94 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 95 | self.transform = torch.nn.Sequential( 96 | torch.nn.Linear(token_dim, encoder_hidden_size), 97 | torch.nn.Tanh(), 98 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), 99 | ) 100 | else: 101 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) 102 | 103 | def forward(self, prefix: torch.Tensor): 104 | if self.prefix_projection: 105 | prefix_tokens = self.embedding(prefix) 106 | past_key_values = self.transform(prefix_tokens) 107 | else: 108 | past_key_values = self.embedding(prefix) 109 | return past_key_values 110 | -------------------------------------------------------------------------------- /peft/tuners/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | import math 18 | from dataclasses import dataclass, field 19 | from typing import Optional, Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptTuningInit(str, enum.Enum): 27 | TEXT = "TEXT" 28 | RANDOM = "RANDOM" 29 | 30 | 31 | @dataclass 32 | class PromptTuningConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEmbedding`]. 35 | 36 | Args: 37 | prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding. 38 | prompt_tuning_init_text (`str`, *optional*): 39 | The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`. 40 | tokenizer_name_or_path (`str`, *optional*): 41 | The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`. 42 | """ 43 | 44 | prompt_tuning_init: Union[PromptTuningInit, str] = field( 45 | default=PromptTuningInit.RANDOM, 46 | metadata={"help": "How to initialize the prompt tuning parameters"}, 47 | ) 48 | prompt_tuning_init_text: Optional[str] = field( 49 | default=None, 50 | metadata={ 51 | "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 52 | }, 53 | ) 54 | tokenizer_name_or_path: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 58 | }, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.PROMPT_TUNING 63 | 64 | 65 | class PromptEmbedding(torch.nn.Module): 66 | """ 67 | The model to encode virtual tokens into prompt embeddings. 68 | 69 | Args: 70 | config ([`PromptTuningConfig`]): The configuration of the prompt embedding. 71 | word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model. 72 | 73 | **Attributes**: 74 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding. 75 | 76 | Example: 77 | 78 | ```py 79 | >>> from peft import PromptEmbedding, PromptTuningConfig 80 | 81 | >>> config = PromptTuningConfig( 82 | ... peft_type="PROMPT_TUNING", 83 | ... task_type="SEQ_2_SEQ_LM", 84 | ... num_virtual_tokens=20, 85 | ... token_dim=768, 86 | ... num_transformer_submodules=1, 87 | ... num_attention_heads=12, 88 | ... num_layers=12, 89 | ... prompt_tuning_init="TEXT", 90 | ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral", 91 | ... tokenizer_name_or_path="t5-base", 92 | ... 
) 93 | 94 | >>> # t5_model.shared is the word embeddings of the base model 95 | >>> prompt_embedding = PromptEmbedding(config, t5_model.shared) 96 | ``` 97 | 98 | Input Shape: (`batch_size`, `total_virtual_tokens`) 99 | 100 | Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 101 | """ 102 | 103 | def __init__(self, config, word_embeddings): 104 | super().__init__() 105 | 106 | total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 107 | self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim) 108 | if config.prompt_tuning_init == PromptTuningInit.TEXT: 109 | from transformers import AutoTokenizer 110 | 111 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path) 112 | init_text = config.prompt_tuning_init_text 113 | init_token_ids = tokenizer(init_text)["input_ids"] 114 | # Trim or iterate until num_text_tokens matches total_virtual_tokens 115 | num_text_tokens = len(init_token_ids) 116 | if num_text_tokens > total_virtual_tokens: 117 | init_token_ids = init_token_ids[:total_virtual_tokens] 118 | elif num_text_tokens < total_virtual_tokens: 119 | num_reps = math.ceil(total_virtual_tokens / num_text_tokens) 120 | init_token_ids = init_token_ids * num_reps 121 | init_token_ids = init_token_ids[:total_virtual_tokens] 122 | 123 | word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone() 124 | word_embedding_weights = word_embedding_weights.to(torch.float32) 125 | self.embedding.weight = torch.nn.Parameter(word_embedding_weights) 126 | 127 | def forward(self, indices): 128 | # Just get embeddings 129 | prompt_embeddings = self.embedding(indices) 130 | return prompt_embeddings 131 | -------------------------------------------------------------------------------- /peft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
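
The `TEXT` initialization branch of `PromptEmbedding` in `peft/tuners/prompt_tuning.py` above trims or tiles the tokenized init text so that it exactly covers `total_virtual_tokens`. The following standalone sketch isolates that length-adjustment step; the token ids are made up for illustration.

```python
# Standalone illustration of the trim/tile logic used by PromptEmbedding's TEXT initialization.
import math

def adjust_to_length(init_token_ids, total_virtual_tokens):
    num_text_tokens = len(init_token_ids)
    if num_text_tokens > total_virtual_tokens:
        # Too long: keep only the first total_virtual_tokens ids.
        return init_token_ids[:total_virtual_tokens]
    # Too short: repeat the ids, then truncate to the exact length.
    num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
    return (init_token_ids * num_reps)[:total_virtual_tokens]

print(adjust_to_length([11, 22, 33], 8))  # [11, 22, 33, 11, 22, 33, 11, 22]
print(adjust_to_length([11, 22, 33], 2))  # [11, 22]
```
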
19 | 20 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType 21 | from .other import ( 22 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, 24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, 25 | CONFIG_NAME, 26 | WEIGHTS_NAME, 27 | _set_trainable, 28 | bloom_model_postprocess_past_key_value, 29 | prepare_model_for_int8_training, 30 | shift_tokens_right, 31 | transpose, 32 | _get_submodules, 33 | _set_adapter, 34 | _freeze_adapter, 35 | ModulesToSaveWrapper, 36 | ) 37 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict 38 | -------------------------------------------------------------------------------- /peft/utils/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import enum 16 | import json 17 | import os 18 | from dataclasses import asdict, dataclass, field 19 | from typing import Optional, Union 20 | 21 | from huggingface_hub import hf_hub_download 22 | from transformers.utils import PushToHubMixin 23 | 24 | from .other import CONFIG_NAME 25 | 26 | 27 | class PeftType(str, enum.Enum): 28 | PROMPT_TUNING = "PROMPT_TUNING" 29 | P_TUNING = "P_TUNING" 30 | PREFIX_TUNING = "PREFIX_TUNING" 31 | LORA = "LORA" 32 | ADALORA = "ADALORA" 33 | ADAPTION_PROMPT = "ADAPTION_PROMPT" 34 | 35 | 36 | class TaskType(str, enum.Enum): 37 | SEQ_CLS = "SEQ_CLS" 38 | SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" 39 | CAUSAL_LM = "CAUSAL_LM" 40 | TOKEN_CLS = "TOKEN_CLS" 41 | 42 | 43 | @dataclass 44 | class PeftConfigMixin(PushToHubMixin): 45 | r""" 46 | This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all 47 | PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to 48 | push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a 49 | directory. The method `from_pretrained` will load the configuration of your adapter model from a directory. 50 | 51 | Args: 52 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 53 | """ 54 | peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."}) 55 | 56 | @property 57 | def __dict__(self): 58 | return asdict(self) 59 | 60 | def to_dict(self): 61 | return self.__dict__ 62 | 63 | def save_pretrained(self, save_directory, **kwargs): 64 | r""" 65 | This method saves the configuration of your adapter model in a directory. 66 | 67 | Args: 68 | save_directory (`str`): 69 | The directory where the configuration will be saved. 70 | kwargs (additional keyword arguments, *optional*): 71 | Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`] 72 | method. 
73 | """ 74 | if os.path.isfile(save_directory): 75 | raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") 76 | 77 | os.makedirs(save_directory, exist_ok=True) 78 | 79 | output_dict = self.__dict__ 80 | output_path = os.path.join(save_directory, CONFIG_NAME) 81 | 82 | # save it 83 | with open(output_path, "w") as writer: 84 | writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) 85 | 86 | @classmethod 87 | def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs): 88 | r""" 89 | This method loads the configuration of your adapter model from a directory. 90 | 91 | Args: 92 | pretrained_model_name_or_path (`str`): 93 | The directory or the Hub repository id where the configuration is saved. 94 | kwargs (additional keyword arguments, *optional*): 95 | Additional keyword arguments passed along to the child class initialization. 96 | """ 97 | path = ( 98 | os.path.join(pretrained_model_name_or_path, subfolder) 99 | if subfolder is not None 100 | else pretrained_model_name_or_path 101 | ) 102 | if os.path.isfile(os.path.join(path, CONFIG_NAME)): 103 | config_file = os.path.join(path, CONFIG_NAME) 104 | else: 105 | try: 106 | config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder) 107 | except Exception: 108 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") 109 | 110 | loaded_attributes = cls.from_json_file(config_file) 111 | 112 | config = cls(**kwargs) 113 | 114 | for key, value in loaded_attributes.items(): 115 | if hasattr(config, key): 116 | setattr(config, key, value) 117 | 118 | return config 119 | 120 | @classmethod 121 | def from_json_file(cls, path_json_file, **kwargs): 122 | r""" 123 | Loads a configuration file from a json file. 124 | 125 | Args: 126 | path_json_file (`str`): 127 | The path to the json file. 128 | """ 129 | with open(path_json_file, "r") as file: 130 | json_object = json.load(file) 131 | 132 | return json_object 133 | 134 | 135 | @dataclass 136 | class PeftConfig(PeftConfigMixin): 137 | """ 138 | This is the base configuration class to store the configuration of a [`PeftModel`]. 139 | 140 | Args: 141 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 142 | task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. 143 | inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode. 144 | """ 145 | 146 | base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."}) 147 | peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"}) 148 | task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"}) 149 | inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"}) 150 | 151 | 152 | @dataclass 153 | class PromptLearningConfig(PeftConfig): 154 | """ 155 | This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or 156 | [`PromptTuning`]. 157 | 158 | Args: 159 | num_virtual_tokens (`int`): The number of virtual tokens to use. 160 | token_dim (`int`): The hidden embedding dimension of the base transformer model. 161 | num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model. 162 | num_attention_heads (`int`): The number of attention heads in the base transformer model. 
163 | num_layers (`int`): The number of layers in the base transformer model. 164 | """ 165 | 166 | num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"}) 167 | token_dim: int = field( 168 | default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"} 169 | ) 170 | num_transformer_submodules: Optional[int] = field( 171 | default=None, metadata={"help": "Number of transformer submodules"} 172 | ) 173 | num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"}) 174 | num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"}) 175 | -------------------------------------------------------------------------------- /peft/utils/other.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import copy 17 | 18 | import torch 19 | 20 | 21 | # needed for prefix-tuning of bloom model 22 | def bloom_model_postprocess_past_key_value(past_key_values): 23 | past_key_values = torch.cat(past_key_values) 24 | total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape 25 | keys = past_key_values[: total_layers // 2] 26 | keys = keys.transpose(2, 3).reshape( 27 | total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens 28 | ) 29 | values = past_key_values[total_layers // 2 :] 30 | values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim) 31 | 32 | return tuple(zip(keys, values)) 33 | 34 | 35 | def prepare_model_for_int8_training( 36 | model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=["layer_norm"] 37 | ): 38 | r""" 39 | This method wraps the entire protocol for preparing a model before running a training. 
This includes: 40 | 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm 41 | head to fp32 42 | 43 | Args: 44 | model, (`transformers.PreTrainedModel`): 45 | The loaded model from `transformers` 46 | """ 47 | loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False) 48 | 49 | for name, param in model.named_parameters(): 50 | # freeze base model's layers 51 | param.requires_grad = False 52 | 53 | if loaded_in_8bit: 54 | # cast layer norm in fp32 for stability for 8bit models 55 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 56 | param.data = param.data.to(torch.float32) 57 | 58 | if loaded_in_8bit and use_gradient_checkpointing: 59 | # For backward compatibility 60 | if hasattr(model, "enable_input_require_grads"): 61 | model.enable_input_require_grads() 62 | else: 63 | 64 | def make_inputs_require_grad(module, input, output): 65 | output.requires_grad_(True) 66 | 67 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 68 | 69 | # enable gradient checkpointing for memory efficiency 70 | model.gradient_checkpointing_enable() 71 | 72 | if hasattr(model, output_embedding_layer_name): 73 | output_embedding_layer = getattr(model, output_embedding_layer_name) 74 | input_dtype = output_embedding_layer.weight.dtype 75 | 76 | class CastOutputToFloat(torch.nn.Sequential): 77 | r""" 78 | Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is casted 79 | in fp32 80 | 81 | """ 82 | 83 | def forward(self, x): 84 | return super().forward(x.to(input_dtype)).to(torch.float32) 85 | 86 | setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) 87 | 88 | return model 89 | 90 | 91 | # copied from transformers.models.bart.modeling_bart 92 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 93 | """ 94 | Shift input ids one token to the right. 95 | 96 | Args: 97 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids 98 | pad_token_id (`int`): The id of the `padding` token. 99 | decoder_start_token_id (`int`): The id of the `start` token. 
100 | """ 101 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 102 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 103 | shifted_input_ids[:, 0] = decoder_start_token_id 104 | 105 | if pad_token_id is None: 106 | raise ValueError("self.model.config.pad_token_id has to be defined.") 107 | # replace possible -100 values in labels by `pad_token_id` 108 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 109 | 110 | return shifted_input_ids 111 | 112 | 113 | class ModulesToSaveWrapper(torch.nn.Module): 114 | def __init__(self, module_to_save, adapter_name): 115 | super().__init__() 116 | self.original_module = module_to_save 117 | self.modules_to_save = torch.nn.ModuleDict({}) 118 | self.update(adapter_name) 119 | self.active_adapter = adapter_name 120 | 121 | def update(self, adapter_name): 122 | self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) 123 | 124 | def forward(self, *args, **kwargs): 125 | if self.active_adapter not in self.modules_to_save: 126 | return self.original_module(*args, **kwargs) 127 | return self.modules_to_save[self.active_adapter](*args, **kwargs) 128 | 129 | 130 | def _get_submodules(model, key): 131 | parent = model.get_submodule(".".join(key.split(".")[:-1])) 132 | target_name = key.split(".")[-1] 133 | target = model.get_submodule(key) 134 | return parent, target, target_name 135 | 136 | 137 | def _freeze_adapter(model, adapter_name): 138 | for n, p in model.named_parameters(): 139 | if adapter_name in n: 140 | p.requires_grad = False 141 | 142 | 143 | def _set_trainable(model, adapter_name): 144 | key_list = [key for key, _ in model.named_modules()] 145 | for key in key_list: 146 | target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save) 147 | if target_module_found: 148 | parent, target, target_name = _get_submodules(model, key) 149 | if isinstance(target, ModulesToSaveWrapper): 150 | target.update(adapter_name) 151 | else: 152 | for param in target.parameters(): 153 | param.requires_grad = True 154 | setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name)) 155 | 156 | 157 | def _set_adapter(model, adapter_name): 158 | for module in model.modules(): 159 | if isinstance(module, ModulesToSaveWrapper): 160 | module.active_adapter = adapter_name 161 | 162 | 163 | def fsdp_auto_wrap_policy(model): 164 | import functools 165 | import os 166 | 167 | from accelerate import FullyShardedDataParallelPlugin 168 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy 169 | 170 | from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder 171 | 172 | def lambda_policy_fn(module): 173 | if ( 174 | len(list(module.named_children())) == 0 175 | and getattr(module, "weight", None) is not None 176 | and module.weight.requires_grad 177 | ): 178 | return True 179 | return False 180 | 181 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) 182 | transformer_wrap_policy = functools.partial( 183 | transformer_auto_wrap_policy, 184 | transformer_layer_cls=( 185 | PrefixEncoder, 186 | PromptEncoder, 187 | PromptEmbedding, 188 | FullyShardedDataParallelPlugin.get_module_class_from_name( 189 | model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "") 190 | ), 191 | ), 192 | ) 193 | 194 | auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) 195 | return auto_wrap_policy 196 | 197 | 198 | def transpose(weight, 
fan_in_fan_out): 199 | return weight.T if fan_in_fan_out else weight 200 | 201 | 202 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { 203 | "t5": ["q", "v"], 204 | "mt5": ["q", "v"], 205 | "bart": ["q_proj", "v_proj"], 206 | "gpt2": ["c_attn"], 207 | "bloom": ["query_key_value"], 208 | "blip-2": ["q", "v", "q_proj", "v_proj"], 209 | "opt": ["q_proj", "v_proj"], 210 | "gptj": ["q_proj", "v_proj"], 211 | "gpt_neox": ["query_key_value"], 212 | "gpt_neo": ["q_proj", "v_proj"], 213 | "bert": ["query", "value"], 214 | "roberta": ["query", "value"], 215 | "xlm-roberta": ["query", "value"], 216 | "electra": ["query", "value"], 217 | "deberta-v2": ["query_proj", "value_proj"], 218 | "deberta": ["in_proj"], 219 | "layoutlm": ["query", "value"], 220 | "llama": ["q_proj", "v_proj"], 221 | "chatglm": ["query_key_value"], 222 | } 223 | 224 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = { 225 | "t5": ["q", "k", "v", "o", "wi", "wo"], 226 | "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"], 227 | "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], 228 | # "gpt2": ["c_attn"], 229 | # "bloom": ["query_key_value"], 230 | "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"], 231 | # "gptj": ["q_proj", "v_proj"], 232 | # "gpt_neox": ["query_key_value"], 233 | # "gpt_neo": ["q_proj", "v_proj"], 234 | # "bert": ["query", "value"], 235 | "roberta": ["query", "key", "value", "dense"], 236 | # "xlm-roberta": ["query", "value"], 237 | # "electra": ["query", "value"], 238 | "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"], 239 | # "deberta": ["in_proj"], 240 | # "layoutlm": ["query", "value"], 241 | } 242 | 243 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = { 244 | "bloom": bloom_model_postprocess_past_key_value, 245 | } 246 | 247 | WEIGHTS_NAME = "adapter_model.bin" 248 | CONFIG_NAME = "adapter_config.json" 249 | -------------------------------------------------------------------------------- /peft/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .config import PeftType, PromptLearningConfig 17 | 18 | 19 | def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): 20 | """ 21 | Get the state dict of the Peft model. 22 | 23 | Args: 24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, 25 | the model should be the underlying model/unwrapped model (i.e. model.module). 26 | state_dict (`dict`, *optional*, defaults to `None`): 27 | The state dict of the model. If not provided, the state dict of the model 28 | will be used. 
29 | """ 30 | config = model.peft_config[adapter_name] 31 | if state_dict is None: 32 | state_dict = model.state_dict() 33 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 34 | # to_return = lora_state_dict(model, bias=model.peft_config.bias) 35 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` 36 | # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP 37 | bias = config.bias 38 | if bias == "none": 39 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} 40 | elif bias == "all": 41 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} 42 | elif bias == "lora_only": 43 | to_return = {} 44 | for k in state_dict: 45 | if "lora_" in k: 46 | to_return[k] = state_dict[k] 47 | bias_name = k.split("lora_")[0] + "bias" 48 | if bias_name in state_dict: 49 | to_return[bias_name] = state_dict[bias_name] 50 | else: 51 | raise NotImplementedError 52 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} 53 | if config.peft_type == PeftType.ADALORA: 54 | rank_pattern = config.rank_pattern 55 | if rank_pattern is not None: 56 | rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()} 57 | config.rank_pattern = rank_pattern 58 | to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name) 59 | 60 | elif config.peft_type == PeftType.ADAPTION_PROMPT: 61 | to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} 62 | elif isinstance(config, PromptLearningConfig): 63 | to_return = {} 64 | if config.inference_mode: 65 | prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight 66 | else: 67 | prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) 68 | to_return["prompt_embeddings"] = prompt_embeddings 69 | else: 70 | raise NotImplementedError 71 | if model.modules_to_save is not None: 72 | for key, value in state_dict.items(): 73 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): 74 | to_return[key.replace("modules_to_save.", "")] = value 75 | 76 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} 77 | return to_return 78 | 79 | 80 | def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"): 81 | """ 82 | Set the state dict of the Peft model. 83 | 84 | Args: 85 | model ([`PeftModel`]): The Peft model. 86 | peft_model_state_dict (`dict`): The state dict of the Peft model. 87 | """ 88 | config = model.peft_config[adapter_name] 89 | state_dict = {} 90 | if model.modules_to_save is not None: 91 | for key, value in peft_model_state_dict.items(): 92 | if any(module_name in key for module_name in model.modules_to_save): 93 | for module_name in model.modules_to_save: 94 | if module_name in key: 95 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}") 96 | break 97 | state_dict[key] = value 98 | else: 99 | state_dict = peft_model_state_dict 100 | 101 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 102 | peft_model_state_dict = {} 103 | for k, v in state_dict.items(): 104 | if "lora_" in k: 105 | suffix = k.split("lora_")[1] 106 | if "." 
in suffix: 107 | suffix_to_replace = ".".join(suffix.split(".")[1:]) 108 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") 109 | else: 110 | k = f"{k}.{adapter_name}" 111 | peft_model_state_dict[k] = v 112 | else: 113 | peft_model_state_dict[k] = v 114 | if config.peft_type == PeftType.ADALORA: 115 | rank_pattern = config.rank_pattern 116 | if rank_pattern is not None: 117 | model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) 118 | elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT: 119 | peft_model_state_dict = state_dict 120 | else: 121 | raise NotImplementedError 122 | 123 | model.load_state_dict(peft_model_state_dict, strict=False) 124 | if isinstance(config, PromptLearningConfig): 125 | model.prompt_encoder[adapter_name].embedding.load_state_dict( 126 | {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True 127 | ) 128 | -------------------------------------------------------------------------------- /pics/dingding_groups.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michael-wzhu/PromptCBLUE/b0753a61a7c1f4e1ae171109f8a59037ff0a5543/pics/dingding_groups.jpg -------------------------------------------------------------------------------- /pics/promptCBLUE_banner_v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michael-wzhu/PromptCBLUE/b0753a61a7c1f4e1ae171109f8a59037ff0a5543/pics/promptCBLUE_banner_v0.png -------------------------------------------------------------------------------- /pics/promptCBLUE_en_banner_v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michael-wzhu/PromptCBLUE/b0753a61a7c1f4e1ae171109f8a59037ff0a5543/pics/promptCBLUE_en_banner_v0.png -------------------------------------------------------------------------------- /pics/wechat_qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michael-wzhu/PromptCBLUE/b0753a61a7c1f4e1ae171109f8a59037ff0a5543/pics/wechat_qrcode.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | huggingface_hub==0.14.1 3 | jieba==0.42.1 4 | nltk==3.8.1 5 | numpy==1.24.3 6 | rouge_chinese==1.0.3 7 | sentencepiece==0.1.99 8 | torch==1.12.0+cu113 9 | tqdm==4.65.0 10 | transformers==4.28.1 11 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## 依赖 4 | 5 | 我们根据我们实验所用的环境生成了[requirements.txt](./requirements.txt)。参赛者可以自行配置版本更高的环境。 6 | 7 | ## 模型下载 8 | 9 | 先将ChatGLM权重下载到本地。通过下面的命令, 10 | ```bash 11 | python src/download_checkpoints.py 12 | ``` 13 | 模型会存放在类似于`./models--THUDM--chatglm-6b/snapshots/a8ede826cf1b62bd3c78bdfb3625c7c5d2048fbd`的路径中,加载模型时候就是采用这个路径。 14 | 15 | 16 | ## 数据 17 | 18 | 请前往[PromptCBLUE通用赛道评测网站](https://tianchi.aliyun.com/competition/entrance/532085/introduction)或者[PromptCBLUE通用赛道评测网站](https://tianchi.aliyun.com/competition/entrance/532084/introduction)下载训练集,验证集以及测试集A或者测试集B。这些数据放置在自己指定的文件夹中,如"datasets/PromptCBLUE/toy_examples"。 19 | 20 | 21 | ## ChatGLM-6B + P-tuning 方法 22 | 23 | 这部分代码借鉴了ChatGLM-6B官方的p-tuning代码。 24 | 25 | ### 训练 26 | 27 | ```bash 28 | 
./src/ft_chatglm_ptuning/train.sh 29 | 30 | ``` 31 | 32 | 33 | ### 预测(生成回复) 34 | 35 | ```bash 36 | ./src/ft_chatglm_ptuning/evaluate.sh 37 | 38 | ``` 39 | 40 | 41 | 42 | ## ChatGLM-6B + LoRA方法微调 43 | 44 | 这部分代码实现借助了[PEFT项目](https://github.com/huggingface/peft)。注意PEFT直接用pip安装的话,需要torch==2.0以上,同时cuda也需要高版本。如果大家不想更新torch环境,可以直接拷贝他的核心代码,放在自己的代码库里面,如[./src/ft_chatglm_lora/peft](./src/ft_chatglm_lora/peft),这样就可以在更低版本的torch环境下使用。 45 | 46 | 注意ChatGLM-6B采用了query,key,value矩阵参数共享,所以LoRA作用的模块名称是与其他模型不同的。我们这里要LoRA作用于`query_key_value,dense,dense_h_to_4h,dense_4h_to_h`这些模块。 47 | 48 | 49 | ### 训练 50 | 51 | ```bash 52 | 53 | src/ft_chatglm_lora/train.sh 54 | 55 | ``` 56 | 57 | ### 预测(生成回复) 58 | 59 | 预测时,可以根据自身判断,选择调整`src/ft_chatglm_lora/main.py`代码的445行到455行的模型生成设置,比如`num_beams`, `do_sample`等。我们现在设置`do_sample=False`和`num_beams=1`,即采用贪心解码。自然地,设置更大的`num_beams`相应的可以提升生成效果,不过也会带来显存压力。 60 | 61 | 同时,大家根据卡的显存,设置下面脚本中的`per_device_eval_batch_size`取值。我们目前的生成设置和脚本设置的入参,推理需要25G显存,在V100 (40G)上单卡5个小时左右跑完测试集。 62 | 63 | ```bash 64 | ./src/ft_chatglm_lora/evaluate.sh 65 | 66 | ``` 67 | 68 | 预测效率提升有很多途径:包括模型量化,或者使用推理框架,如vLLM。 69 | 70 | 71 | 72 | ## LlaMA-7B + LoRA方法微调 73 | 74 | ### LlaMA模型准备 75 | 76 | 我们先要准备LlaMA模型底座,使得其可以在huggingface transformers框架下进行参数高效微调。准备工作主要有三步: 77 | 78 | #### LlaMA模型主干 79 | 80 | 获取LlaMA模型主干有几种途径: 81 | - 原版LLaMA模型: 在[LlaMA原项目地址](https://github.com/facebookresearch/llama)填写google form申请; 82 | - [LlaMA项目的一个PR](https://github.com/facebookresearch/llama/pull/73/files) 83 | - huggingface的model hub中已经人上传了模型: [decapoda-research/llama-7b-hf](https://huggingface.co/decapoda-research/llama-7b-hf) 84 | 85 | #### LlaMA模型权重转化 86 | 87 | 上一步骤的前两种方法需要将LlaMA模型权重转化为huggingface transformers的格式,详见[convert_llama_weights_to_hf](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py))。 88 | 89 | 90 | #### 融合Chinese-LlaMA-Alpaca 91 | 92 | [Chinese-LlaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/README_EN.md)项目提供了使得LlaMA模型更适应于中文场景的lora权重和经过继续预训练的embedding权重。我们采用其脚本将其权重合并到模型主干中: 93 | 94 | ```bash 95 | python src/ft_llama_lora/merge_llama_with_chinese_lora.py \ 96 | --base_model decapoda-research/llama-7b-hf \ 97 | --lora_model ziqingyang/chinese-llama-plus-lora-7b,ziqingyang/chinese-alpaca-plus-lora-7b \ 98 | --output_type huggingface \ 99 | --output_dir ./resources/chinese-llama-alpaca-plus-lora-7b 100 | 101 | ``` 102 | 103 | 注意上述命令中我们合并了[Chinese-LlaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)的两个lora权重,第一个权重是做了大规模中文语料预训练,第二个权重则是进一步做了基于self-instruct的中文指令微调。两者合并可以得到更会说中国话的LlaMA模型。 104 | 105 | 注意存储合并参数后的llama模型时,尽量模型文件shards多分一些,比如`max_shard_size=2GB`, 这样加载起来也会快一些。 106 | 107 | 108 | ### 训练 109 | 110 | ```bash 111 | 112 | CUDA_VISIBLE_DEVICES="2,3" ./my_scripts/promptcblue_fft/run_train.sh 113 | 114 | ``` 115 | 116 | ### 预测(生成回复) 117 | 118 | 预测时,我们采用[vllm项目](https://github.com/vllm-project/vllm)对模型进行serving.同时这部分代码参照了[KnowLM项目](https://github.com/zjunlp/KnowLM/tree/main/inference) 119 | 120 | 在使用vllm时,我们首先要把训练得到的lora参数与LlaMA主干进行合并 (假设我们采用训练第800步的lora权重): 121 | 122 | ```bash 123 | 124 | CUDA_VISIBLE_DEVICES="3" python src/ft_llama_lora/merge_llama_with_chinese_lora.py \ 125 | --base_model ./resources/chinese-llama-plus-lora-7b \ 126 | --lora_model ./experiments/output/promptcblue-llama-7b-pt-v0/checkpoint-800 \ 127 | --output_type huggingface \ 128 | --output_dir ./experiments/output/promptcblue-llama-7b-pt-v0/checkpoint-800-merge 129 | 130 | ``` 131 | 132 | 
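合并完成后,可以先用 transformers 简单加载一下合并后的权重,确认模型和分词器都能正常载入并生成。下面只是一个示意性的自检脚本:路径沿用上文合并命令的输出目录,示例 prompt 改编自前文的 CMeEE 样例,并非必需步骤:

```python
# 示意脚本:加载合并后的权重做快速自检(非必需步骤)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "./experiments/output/promptcblue-llama-7b-pt-v0/checkpoint-800-merge"

tokenizer = AutoTokenizer.from_pretrained(merged_dir)
model = AutoModelForCausalLM.from_pretrained(
    merged_dir,
    torch_dtype=torch.float16,
    device_map="auto",
)

prompt = "医学实体识别:\n外周血白细胞计数常明显升高,伴核左移。\n实体选项:疾病,医学检验项目,临床表现,药物\n答:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

确认可以正常生成后,再进行后续的 vllm 部署。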
然后采用下面的命令启动模型服务。注意,我们修改了`src/ft_llama_lora/vllm_serving/llm_engine.py`第148行的`gpu_memory_utilization`参数取值,大家可以根据显卡情况修改。 133 | 134 | ```bash 135 | CUDA_VISIBLE_DEVICES="3" python src/ft_llama_lora/vllm_serving/launch_vllm.py \ 136 | --port 8000 \ 137 | --model ./experiments/output/promptcblue-llama-7b-pt-v0/checkpoint-800-merge \ 138 | --use-np-weights \ 139 | --max-num-batched-tokens 4096 \ 140 | --dtype half \ 141 | --tensor-parallel-size 1 142 | 143 | ``` 144 | 145 | 我们在生成的时候,不会传入有效的`parameters`字段,所以采样参数会使用`src/ft_llama_lora/vllm_serving/launch_vllm.py`的63行处`SamplingParams`的默认值。大家可以根据需求修改。vllm服务起好之后,我们可以通过下面的例子进行服务调用,从而进行测试集预测: 146 | 147 | ```bash 148 | python src/ft_llama_lora/vllm_serving/web_service_test.py 149 | 150 | ``` 151 | 152 | 通过vllm部署模型,我们测试下来预计加速2.5倍左右。 153 | 154 | 155 | 156 | 157 | ## Contributors 158 | 159 | - [michael-wzhu](https://github.com/michael-wzhu) 160 | - [boom-R123](https://github.com/boom-R123) 161 | -------------------------------------------------------------------------------- /src/data/CBLUE任务改造说明与举例.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### CBLUE任务改造 4 | 5 | 我们采用94个指令微调模板,对CBLUE基准中的各个任务进行。经过改造后,医疗文本NLP数据集都将转化为如下格式。input字段是模型的输入,target字段是模型的输出,type是原任务类型(不作为模型输入),answer_choices字段是选项,只有分类、术语标准化、推理类任务上该字段才会有意义。 6 | 7 | ```json 8 | { 9 | "input": str, 10 | "target": str, 11 | "type": str, 12 | "answer_choices": str, 13 | "sample_id": str, 14 | } 15 | ``` 16 | 17 | 为了将CBLUE中的各种不同任务适配为符合LLM的输入输出格式,我们对CBLUE各个数据集进行了相应的改造。 18 | 19 | 20 | #### CMeEE任务 21 | 22 | 本任务原本是标准的医学文本NER任务,选手需要给出医学实体mention在待抽取文本中的具体span位置。在PromptCBLUE中,本任务被改造为:根据指定的实体类型,生成实体mention。在评分时,我们只考虑实体mention及其类型标签,不再考虑span位置信息。而且,特别注意的是,为了考察模型的指令理解与服从能力(instruction following),模型只能生成指令中指定的实体类型,而不能生成其他类型的实体。样例如下: 23 | 24 | ```json 25 | { 26 | "input": "医学实体识别:\n外周血白细胞计数常明显升高,伴核左移。\n实体选项:疾病,医学检验项目,医院科室,身体部位,微生物类,临床表现,药物\n答:", 27 | "target": "上述句子中的实体包含:\n医学检验项目实体:外周血白细胞计数\n疾病实体:\n医院科室实体:\n药物实体:\n微生物类实体:", 28 | "answer_choices": ["疾病", "医学检验项目", "医院科室", "身体部位", "微生物类", "临床表现", "药物"], 29 | "task_type": "ner", 30 | "task_dataset": "CMeEE-V2", 31 | "sample_id": "train-134372" 32 | } 33 | ``` 34 | 35 | 上述样例中的target即为模型输出,而评测参与者需要根据自己的LLM输出进行解析,得到抽取结果。LLM输出的格式可以自己定义,也可以根据我们的样例来进行。 36 | 37 | 38 | #### CMeIE任务 39 | 40 | 本任务是三元组联合抽取任务。在PromptCBLUE中,我们将其定义为:在指定的关系类型下,抽取形成该关系的头尾实体mention。参赛者可以根据需要对本任务的指令/提示进行进一步拆解,以更好的完成任务。示例如下: 41 | 42 | ```json 43 | { 44 | "input": "找出句子中的具有临床表现,同义词关系类型的头尾实体对:\n妊娠期高血压@健康孕妇可以通过全身血管阻力 (systemic vascular resistance,SVR) 和心输出量下降得以平衡,而妊娠期高血压妇女SVR较低,心输出量往往会增加。", 45 | "target": "上述句子中临床表现关系的实体对如下:\n头实体:妊娠期高血压,尾实体:SVR较低;头实体:妊娠期高血压,尾实体:心输出量往往会增加;\n上述句子中同义词关系的实体对如下:", 46 | "answer_choices": ["临床表现", "同义词"], 47 | "task_type": "spo_generation", 48 | "task_dataset": "CMeIE", 49 | "sample_id": "train-67405" 50 | } 51 | ``` 52 | 53 | 54 | #### CHIP-CDEE任务 55 | 56 | 本任务是医学事件抽取,即给定一段现病史或者医学影像所见报告,要求从中抽取临床发现事件的四个属性: 解剖部位、主体词、描述词,以及发生状态。在PromptCBLUE下,该任务表现为根据给定文本,生成相应的事件信息描述。注意原任务的标注中,部分主体词是做了ICD10标准化的,我们将这一部分样本过滤掉了,也就是说我们目前主要是评估LLM的医学事件抽取能力。 57 | 58 | ```json 59 | 60 | { 61 | "input": "7月前患者给予亚砷酸氯化钠(伊泰达)注射液 10mg 静滴14天,6月前予以口服维甲酸 20mg bid*14天维持治疗,5月前行亚砷酸氯化钠(伊泰达)注射液 10mg 静滴14天维持化疗,3月余前复查骨髓检查示增生性骨髓象;fish:pml/rara(双色双融合)(15/17):未见异常;腰穿脑脊液未见异常细胞。现为维持化疗入院。\n问题:句子中的临床发现事件及其属性是什么?\n说明:临床发现事件由主体词,发生状态,描述词和解剖部位组成", 62 | "target": "上述句子中的临床发现事件如下:\n主体词:fish:pml/rara(双色双融合)(15/17)异常;发生状态:否定;描述词:;解剖部位:\n主体词:骨髓象;发生状态:;描述词:增生性;解剖部位:骨髓", 63 | "answer_choices": null, 64 | "task_type": "event_extraction", 65 | "task_dataset": "CHIP-CDEE", 66 | 
"sample_id": "train-17503" 67 | } 68 | 69 | ``` 70 | 71 | 72 | 73 | #### CHIP-CDN任务 74 | 75 | CHIP-CDN任务在CBLUE中的原型是:给定一诊断原词,要求给出其对应的诊断标准词,而诊断标准词是从ICD-10这一个4w+的标准词库中选择。由于我们不可能一次性将四万个词输入到LLM中(即使是GPT-4服务,最多只能输入32000个token),所以我们将CDN任务改造为:给定原词,从候选的若干个ICD-10诊断标准词中选择出匹配的词(可能有多个, 可能一个都没有)。而在实际业务中,我们可以结合对本地知识库的检索+LLM判断的方式,完整的预测诊断原词对应的诊断标准词。 76 | 77 | ```json 78 | 79 | { 80 | "input": "主动脉弓缩窄心功能低下\n归一化后的标准词是?\n实体选项:胫前动脉假性动脉瘤,主动脉缩窄,男性性腺功能低下,男性性腺功能低下,垂体功能低下,心功能不全\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\n答:", 81 | "target": "主动脉缩窄,心功能不全", 82 | "answer_choices": ["胫前动脉假性动脉瘤", "主动脉缩窄", "男性性腺功能低下", "男性性腺功能低下", "垂体功能低下", "心功能不全"], 83 | "task_type": "normalization", 84 | "task_dataset": "CHIP-CDN", 85 | "sample_id": "train-17932" 86 | } 87 | 88 | 89 | 90 | ``` 91 | 92 | 93 | 94 | #### CHIP-CTC任务 95 | 96 | 本次评测任务的主要目标是针对临床试验筛选标准进行分类。注意到原数据集中一个占比较高的类型为"other"类,在PromptCBLUE下,我们将此类型名称取消。本任务转化为:根据给定的临床试验筛选标准类型,生成出给定句子的类型,或者回答"非上述类型"表明指令中给出的类型与句子不符合。 97 | 98 | 99 | ```json 100 | 101 | { 102 | "input": "8.过去3个月内有过眼内手术的患者;\n这句话是什么临床试验筛选标准类型?\n类型选项:成瘾行为,吸烟状况,性取向,残疾群体,读写能力,肿瘤进展,参与其它试验,疾病分期,能力,疾病,药物,诊断,教育情况,口腔相关,受体状态,健康群体,数据可及性,设备,献血,过敏耐受,特殊病人特征,睡眠,怀孕相关,研究者决定,器官组织状态,症状(患者感受),治疗或手术,护理,性别,种族,实验室检查,知情同意,饮食,年龄,居住情况,病例来源,酒精使用,体征(医生检测),锻炼,风险评估,预期寿命,伦理审查,依存性", 103 | "target": "治疗或手术", 104 | "answer_choices": ["成瘾行为", "吸烟状况", "性取向", "残疾群体", "读写能力", "肿瘤进展", "参与其它试验", "疾病分期", "能力", "疾病", "药物", "诊断", "教育情况", "口腔相关", "受体状态", "健康群体", "数据可及性", "设备", "献血", "过敏耐受", "特殊病人特征", "睡眠", "怀孕相关", "研究者决定", "器官组织状态", "症状(患者感受)", "治疗或手术", "护理", "性别", "种族", "实验室检查", "知情同意", "饮食", "年龄", "居住情况", "病例来源", "酒精使用", "体征(医生检测)", "锻炼", "风险评估", "预期寿命", "伦理审查", "依存性"], 105 | "task_type": "cls", 106 | "task_dataset": "CHIP-CTC", 107 | "sample_id": "train-19957" 108 | }, 109 | { 110 | "input": "判断临床试验筛选标准的类型:\n精神分裂症患者组(Schizophrenia,Sch):符合 DSM-IV偏执型分裂症诊断标准,年龄、性别、精神病病程与MAP组匹配,听力及视力正常、或矫正后处于正常范围。\n选项:残疾群体,口腔相关,诊断,能力,年龄,饮食,研究者决定,献血,参与其它试验,设备,护理,性别,症状(患者感受),依存性,睡眠", 111 | "target": "非上述类型", 112 | "answer_choices": ["残疾群体", "口腔相关", "诊断", "能力", "年龄", "饮食", "研究者决定", "献血", "参与其它试验", "设备", "护理", "性别", "症状(患者感受)", "依存性", "睡眠"], 113 | "task_type": "cls", 114 | "task_dataset": "CHIP-CTC", 115 | "sample_id": "train-63105" 116 | } 117 | 118 | ``` 119 | 120 | 121 | #### KUAKE-QIC任务 122 | 123 | 本任务是对医学场景的搜索问题的意图分类。注意到原数据集中一个占比较高的类型为"其他"类,在PromptCBLUE下,我们将此类型名称取消,并将本任务转化为:根据给定的医学搜索意图分类标签,生成出给定句子的类型,或者回答"非上述类型"表明prompt中指定的类型与句子不符合。 124 | 125 | ```json 126 | 127 | { 128 | "input": "确定检索词的类型:\n怎样备孕\n类型选项:病情诊断,功效作用,注意事项,治疗方案,后果表述,就医建议,医疗费用\n答:", 129 | "target": "非上述类型", 130 | "answer_choices": ["病情诊断", "功效作用", "注意事项", "治疗方案", "后果表述", "就医建议", "医疗费用"], 131 | "task_type": "cls", 132 | "task_dataset": "KUAKE-QIC", 133 | "sample_id": "train-29316" 134 | } 135 | { 136 | "input": "判断下面搜索词的意图:\n武汉传染性尖锐湿疣的治疗方法\n选项:指标解读,治疗方案,功效作用,注意事项,病情诊断,就医建议,疾病描述\n答:", 137 | "target": "治疗方案", 138 | "answer_choices": ["指标解读", "治疗方案", "功效作用", "注意事项", "病情诊断", "就医建议", "疾病描述"], 139 | "task_type": "cls", 140 | "task_dataset": "KUAKE-QIC", 141 | "sample_id": "train-17235" 142 | } 143 | 144 | 145 | ``` 146 | 147 | #### CHIP-STS任务 148 | 149 | 本任务旨在判断两个与疾病相关的问句是否表达相同语义。在PromptCBLUE下,本任务即为:根据输入的两个问句,输出模型对其语义是否相同的判断("是的", "不是")。示例如下: 150 | 151 | ```json 152 | 153 | { 154 | "input": "下面两个句子语义是“相同”或“不同”?\n“糖尿病的三多一少是什么”,“无限极的“灵芝皇”和“桑唐饮”能治好糖尿病吗?”。\n选项:相同,不同\n答:", 155 | "target": "不同", 156 | "answer_choices": ["相同", "不同"], 157 | "task_type": "cls", 158 | "task_dataset": "CHIP-STS", 159 | "sample_id": "train-54981" 160 | } 161 | 162 | ``` 163 | 164 | 165 | #### KUAKE-QTR任务 166 | 167 | 
本任务旨在判断医疗搜索场景下搜索词Query主题和落地页标题(Title)主题是否一致及达到多大程度上的一致,输出匹配分数(0分, 1分, 2分, 3分)。在PromptCBLUE下,我们将任务转化为:要求LLM评估匹配程度,即输出描述匹配程度的标签词语。 168 | 169 | ```json 170 | 171 | { 172 | "input": "下面的搜索词和页面标签的意思有多相同?\n搜索词:宝宝三周了发烧不玩睡觉\n页面标签:孩子三周了手足口发烧一天就不烧了就是睡觉打搀\n选项:完全不匹配或者没有参考价值,很少匹配有一些参考价值,部分匹配,完全匹配\n答:", 173 | "target": "部分匹配", 174 | "answer_choices": ["完全不匹配或者没有参考价值", "很少匹配有一些参考价值", "部分匹配", "完全匹配"], 175 | "task_type": "matching", 176 | "task_dataset": "KUAKE-QTR", 177 | "sample_id": "train-67418" 178 | } 179 | 180 | ``` 181 | 182 | 183 | 184 | #### KUAKE-QQR任务 185 | 186 | 本任务旨在判断两个医疗方面得查询(query)的语义关系。在本任务下,LLM需要判断两个查询语句的语义是否完全相等,前者语义覆盖后者,或者后者语义包含前者,或者是毫无关系。示例如下: 187 | 188 | ```json 189 | 190 | { 191 | "input": "下面两个句子的语义关系是?\n“伤口涂什么药好得快”,“有伤口涂什么药”。\n选项: 完全一致,后者是前者的语义子集,后者是前者的语义父集,语义毫无关联", 192 | "target": "完全一致", 193 | "answer_choices": ["完全一致", "后者是前者的语义子集", "后者是前者的语义父集", "语义毫无关联"], 194 | "task_type": "matching", 195 | "task_dataset": "KUAKE-QQR", 196 | "sample_id": "train-88139" 197 | } 198 | 199 | 200 | ``` 201 | 202 | #### KUAKE-IR任务 203 | 204 | 本任务原本的设置为在100w规模的语料库中检索出与医疗query相关的文档(doc)。为使本任务符合LLM特点的任务,我们将其改造为:将一个query和一个doc输入句子中,判断内容是否匹配。举例如下: 205 | 206 | ```json 207 | 208 | { 209 | "input": "医疗搜索:鼻梁被撞鼻梁矫正手术\n以下回答内容是否能够回答搜索问题?\n回答内容:你好,你这中情况一般需要行鼻骨截骨整形及鼻中隔联合矫正,手术需要住院,大概需要10天左右的时间,费用在12000左右,我们医院不对医保,如果是要医保报销需要办转诊手续之后再凭相关单据回所在地报销。\n选项: 相关,不相关\n答:", 210 | "target": "相关", 211 | "answer_choices": ["相关", "不相关"], 212 | "task_type": "matching", 213 | "task_dataset": "KUAKE-IR", 214 | "sample_id": "train-801751" 215 | } 216 | 217 | 218 | ``` 219 | 220 | 221 | #### CHIP-MDCFNPC任务 222 | 223 | 本任务是根据一段患者与医生的对话交互,抽取对话中出现的临床发现实体,并最终判断患者在这个临床发现实体上的阴阳性。阴阳性定义为是患者主诉病情描述和医生诊断判别中的阴性和阳性,包括阴性、阳性、其他、不标注这四种。在PromptCBLUE下,我们采用阴阳性标签的描述语句作为模型的输出目标。示例如下: 224 | 225 | ```json 226 | 227 | { 228 | "input": "患者:月经来了还可吃乌鸡白凤丸和丹栀逍遥丸吗\n医生:请问类似症状出现多长时间?\n医生:你吃这药是治疗什么的\n患者:我前几个月去检查是游离子腺素增高,月经没来\n医生:什么高\n患者:甲状游离子腺素增高\n医生:把化验单给我看一下。\n患者:现在去检查正常值了\n医生:那你如果月经量多这些药就不吃了,如果月经量少就可以吃。\n患者:就是月经不调\n患者:甲状腺素药还有吃\n医生:是甲状腺功能低下吗?甲减吗?\n患者:我在马来西亚看不懂报告单\n医生:嗯嗯,只有甲状腺功能低下才需要吃甲状腺素。\n问题:上述问诊对话中临床发现有哪些?这些实体的阴阳性是?\n阴阳性选项:已有症状疾病或者假设未来可能发生的疾病等,未患有症状疾病,没有回答、不知道、回答不明确或者模棱两可不好推断,无实际意义的不标注或者和病人当前的状态独立不标注\n说明:临床发现是临床医学下,病人状态描述的概念集合", 229 | "target": "上述对话中临床发现实体以及其阴阳性判别如下:\n月经没来:已有症状疾病或者假设未来可能发生的疾病等\n游离子腺素增高:已有症状疾病或者假设未来可能发生的疾病等\n甲状游离子腺素增高:已有症状疾病或者假设未来可能发生的疾病等\n月经量少:无实际意义的不标注或者和病人当前的状态独立不标注\n月经量多:无实际意义的不标注或者和病人当前的状态独立不标注\n月经不调:已有症状疾病或者假设未来可能发生的疾病等\n甲减:没有回答、不知道、回答不明确或者模棱两可不好推断\n甲状腺功能低下:无实际意义的不标注或者和病人当前的状态独立不标注", 230 | "answer_choices": ["已有症状疾病或者假设未来可能发生的疾病等", "未患有症状疾病", "没有回答、不知道、回答不明确或者模棱两可不好推断", "无实际意义的不标注或者和病人当前的状态独立不标注"], 231 | "task_type": "attr_cls", 232 | "task_dataset": "CHIP-MDCFNPC", 233 | "sample_id": "train-982126" 234 | } 235 | 236 | 237 | ``` 238 | 239 | 240 | #### IMCS-V2-NER任务 241 | 242 | 本任务是从医患对话中抽取实体。本任务的指令形式与CMeEE任务相似。示例如下: 243 | 244 | ```json 245 | 246 | { 247 | "input": "下面对话中的医学检查检验,症状,医疗操作实体有哪些?\n宝贝也呕吐吗?\n答:", 248 | "target": "上述句子中的实体包含:\n医学检查检验实体:\n症状实体:呕吐\n医疗操作实体:", 249 | "answer_choices": ["医学检查检验", "症状", "医疗操作"], 250 | "task_type": "ner", 251 | "task_dataset": "IMCS-V2-NER", 252 | "sample_id": "train-63083" 253 | } 254 | 255 | ``` 256 | 257 | #### IMCS-V2-DAC任务 258 | 259 | 本任务旨在根据部分医患对话历史,识别当前对话文本的意图标签。为了将本任务改造的更适合于LLM,我们对本任务的意图进行了改写,使得意图标签名称更符合自然语言形式。本任务的指令形式与QIC任务基本一致。示例如下: 260 | 261 | ```json 262 | 263 | { 264 | "input": "确定这句话的意图:\n当时医生说我们单纯支气管炎也不喘就开的药\n类型选项:关于就医建议的解答,给出诊断,关于症状的回答,关于症状的询问,关于就医建议的提问,关于已有检查和治疗的回答,关于注意事项的提问,关于已有检查和治疗的提问,关于个人基本信息的询问,关于个人基本信息的回答,关于用药建议的解答,关于病因的询问,关于用药建议的提问,关于注意事项的解答,关于病因的回答", 265 | "target": 
"关于已有检查和治疗的回答", 266 | "answer_choices": ["关于就医建议的解答", "给出诊断", "关于症状的回答", "关于症状的询问", "关于就医建议的提问", "关于已有检查和治疗的回答", "关于注意事项的提问", "关于已有检查和治疗的提问", "关于个人基本信息的询问", "关于个人基本信息的回答", "关于用药建议的解答", "关于病因的询问", "关于用药建议的提问", "关于注意事项的解答", "关于病因的回答"], 267 | "task_type": "cls", 268 | "task_dataset": "IMCS-V2-DAC", 269 | "sample_id": "train-41955" 270 | } 271 | 272 | ``` 273 | 274 | 275 | #### IMCS-V2-SR任务 276 | 277 | 本任务包含多个子步骤:(1) 需要从一段医生与患者间的对话中抽取症状; (2)需要给出症状的标准词(从任务给定的400+的标准词库中选择); (3)需要根据对话历史,判断症状的阴阳性,即:患者是否患有该症状, 没有患有该症状, 或者无法根据上下文确定病人是否患有该症状。本任务的指令与CHIP-MDCFNPC任务类似。举例如下: 278 | 279 | ```json 280 | 281 | { 282 | "input": "找出当前对话中的症状,并判断阴阳性:\n对话历史:\n患者:没有怎么听啊\n医生:根据您的描述,宝宝咳嗽,嗓子吼,可能是气喘或喉鸣,考虑支气管炎的可能性较大\n当前对话:\n医生:需要带宝宝去医院儿科就诊,用听诊器听诊肺部,查血常规胸片等相关检查,排除肺炎,根据结果,给于控制感染,止咳化痰等对症治疗。\n症状阴阳性选项:没有患有该症状,患有该症状,无法根据上下文确定病人是否患有该症状\n答:", 283 | "target": "当前对话中的症状及其阴阳性判断为:\n肺炎:无法根据上下文确定病人是否患有该症状\n感染:患有该症状\n咳:患有该症状\n痰:患有该症状", 284 | "answer_choices": ["没有患有该症状", "患有该症状", "无法根据上下文确定病人是否患有该症状"], 285 | "task_type": "attr_cls", 286 | "task_dataset": "IMCS-V2-SR", 287 | "sample_id": "train-5434" 288 | } 289 | 290 | ``` 291 | 292 | 293 | #### IMCS-V2-MRG任务 294 | 295 | 本任务要求是从医患对话中自动生成对应的诊疗报告。标注数据中,统一的将诊疗报告分拆为6个章节(主诉, 现病史, 辅助检查, 既往史, 诊断, 建议)。在PromptCBLUE下,我们要求LLM按照这6个章节的顺序进行一次生成。在评测时,我们仍然会将模型输出拆分为这6个章节。样例如下: 296 | 297 | ```json 298 | 299 | { 300 | "input": "问诊对话历史:\n患者:宝宝刚满月,母乳喂养,最近两天时不时的会咳嗽一声,食欲和精神还行,只不过睡觉不是很安稳。家里面最近两天大人和宝宝的姐姐也有感冒,不知道宝宝是被传染了感冒还是怎么样,请问怎么治疗?\n医生:您好,我是您的辅诊医生,需要询问几个问题,才能更好的评估孩子情况,您还在吗?\n医生:宝宝体温正常吗?\n医生:还在吗?\n患者:在\n医生:您好\n医生:宝宝现在体温正常吗\n患者:体温正常\n医生:口吐泡泡吗\n患者:没有\n医生:嗓子哑吗\n患者:哭起来跟以前一样\n医生:好的\n患者:只不过鼻音重\n医生:还有其他症状吗\n患者:睡不安稳\n医生:出汗多吗?\n患者:不多\n医生:哭闹吗\n患者:比以前爱哭闹\n医生:大便什么样子\n患者:这个没注意,昨天一天没有大便,今天上午大便的,大便以后睡得安稳一些了,不过还是时不时咳嗽一声\n医生:有痰吗?\n患者:没有痰,干咳\n医生:嗯嗯\n医生:流鼻涕吗\n患者:没有\n医生:嗯嗯\n医生:因为宝宝比较小,最好带宝宝去公立医院儿科就诊,听一下肺部没有问题的话可以观察看看\n根据上述对话,给出诊疗报告\n说明:诊疗报告分为主诉, 现病史, 辅助检查, 既往史, 诊断, 建议这六个章节。\n答:", 301 | "target": "上述问诊对话的诊疗报告如下:\n主诉:阵发性咳嗽。\n现病史:患儿阵发性干咳两天。\n辅助检查:暂缺。\n既往史:不详。\n诊断:咳嗽待查。\n建议:儿科就诊,听诊肺部。", 302 | "answer_choices": null, 303 | "task_type": "report_generation", 304 | "task_dataset": "IMCS-V2-MRG", 305 | "sample_id": "train-7798" 306 | } 307 | 308 | ``` 309 | 310 | 311 | #### MedDG任务 312 | 313 | 本任务要求:给定医生和患者交流的对话历史,生成医生的下一句回复。MedDG的原始设置强调实体蕴含,即假定医生的下一句回复需要包含一定的实体。在PromptCBLUE下,我们不对实体蕴含进行要求,只评测LLM作为医生角色的回复生成效果。样例如下: 314 | 315 | ```json 316 | 317 | { 318 | "input": "患者:最近总是到了晚上就胃很难受。这几天吃过饭就有点反胃但是都是头痛头晕恶心。胃还是很难受(女,19岁)\n医生:你好,这种情况有多长时间了?\n患者:半个月了。\n医生:平时吃饭规律吗?\n根据上述对话历史,作为医生应该如何回复?\n答:", 319 | "target": "胃部感觉难受是怎么难受?反酸烧心打嗝?还是胃疼胃胀?", 320 | "answer_choices": null, 321 | "task_type": "response_generation", 322 | "task_dataset": "MedDG", 323 | "sample_id": "dev-19346" 324 | } 325 | ``` 326 | -------------------------------------------------------------------------------- /src/data/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "CMeEE-V2": [ 3 | "找出指定的实体:\\n[INPUT_TEXT]\\n类型选项:[LIST_LABELS]\\n答:", 4 | "找出指定的实体:\\n[INPUT_TEXT]\\n实体类型选项:[LIST_LABELS]\\n答:", 5 | "找出句子中的[LIST_LABELS]实体:\\n[INPUT_TEXT]\\n答:", 6 | "[INPUT_TEXT]\\n问题:句子中的[LIST_LABELS]实体是什么?\\n答:", 7 | "生成句子中的[LIST_LABELS]实体:\\n[INPUT_TEXT]\\n答:", 8 | "下面句子中的[LIST_LABELS]实体有哪些?\\n[INPUT_TEXT]\\n答:", 9 | "实体抽取:\\n[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n答:", 10 | "医学实体识别:\\n[INPUT_TEXT]\\n实体选项:[LIST_LABELS]\\n答:" 11 | ], 12 | "CMeIE": [ 13 | "找出指定的三元组:\\n[INPUT_TEXT]\\n实体间关系:[LIST_LABELS]\\n答:", 14 | "根据给定的实体间的关系,抽取具有这些关系的实体对:\\n[INPUT_TEXT]\\n实体间关系标签:[LIST_LABELS]\\n答:", 15 | "找出句子中的具有[LIST_LABELS]关系类型的头尾实体对:\\n[INPUT_TEXT]\\n答:", 16 
| "[INPUT_TEXT]\\n问题:句子中的[LIST_LABELS]等关系类型三元组是什么?\\n答:", 17 | "给出句子中的[LIST_LABELS]等关系类型的实体对:[INPUT_TEXT]\\n答:", 18 | "[INPUT_TEXT]\\n这个句子里面具有一定医学关系的实体组有哪些?\\n三元组关系选项:[LIST_LABELS]\\n答:", 19 | "同时完成实体识别与关系识别:\\n[INPUT_TEXT]\\n三元组关系类型:[LIST_LABELS]\\n答:" 20 | ], 21 | "CHIP-CDN": [ 22 | "给出下面诊断原词的标准化:\\n[INPUT_TEXT]\\n候选集:[LIST_LABELS]\\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\\n答:", 23 | "找出归一后的标准词:\\n[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\\n答:", 24 | "诊断归一化:\\n[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\\n答:", 25 | "诊断实体的语义标准化:\\n[INPUT_TEXT]\\n实体选项:[LIST_LABELS]\\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\\n答:", 26 | "给出诊断的归一化:\\n[INPUT_TEXT]\\n医学实体选项:[LIST_LABELS]\\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\\n答:", 27 | "[INPUT_TEXT]\\n归一化后的标准词是?\\n实体选项:[LIST_LABELS]\\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\\n答:", 28 | "实体归一化:\\n[INPUT_TEXT]\\n实体候选:[LIST_LABELS]\\n说明:从候选的若干个ICD-10诊断标准词中选择出与原诊断描述匹配的词\\n答:" 29 | ], 30 | "CHIP-CDEE": [ 31 | "临床发现事件抽取:\\n[INPUT_TEXT]\\n说明:临床发现事件的主体词包含发生状态,描述词和解剖部位这三种属性,其中描述词和解剖部位可能有多个值\\n答:", 32 | "找出指定的临床发现事件属性:\\n[INPUT_TEXT]\\n事件抽取说明:临床发现事件由主体词,发生状态,描述词和解剖部位组成\\n答:", 33 | "找出句子中的临床发现事件及其属性:\\n [INPUT_TEXT]\\n说明:临床发现事件的主体词包含发生状态,描述词和解剖部位这三种属性,其中描述词和解剖部位可能有多个值\\n答:", 34 | "[INPUT_TEXT]\\n问题:句子中的临床发现事件及其属性是什么?\\n说明:临床发现事件由主体词,发生状态,描述词和解剖部位组成\\n答:", 35 | "生成句子中的临床发现事件属性是:\\n[INPUT_TEXT]\\n说明:临床发现事件的主体词包含发生状态,描述词和解剖部位这三种属性,其中描述词和解剖部位可能有多个值\\n答:", 36 | "[INPUT_TEXT]\\n这个句子里面临床发现事件是?\\n说明:临床发现事件由主体词,发生状态,描述词和解剖部位组成\\n答:", 37 | "临床发现事件抽取:[INPUT_TEXT]\\n说明:临床发现事件的主体词包含发生状态,描述词和解剖部位这三种属性,其中描述词和解剖部位可能有多个值\\n答:" 38 | ], 39 | "CHIP-STS": [ 40 | "以下两句话的意思相同的吗?\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”\\n选项:是的,不是\\n答:", 41 | "我想知道下面两句话的意思是否相同。\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n选项:是的,不是\\n答:", 42 | "我是否可以用以下的句子:“[INPUT_TEXT_1]”,来替换这个句子:“[INPUT_TEXT_2]”,并且它们有相同的意思?\\n选项:是的,不是\\n答:", 43 | "下面两个句子语义是“相同”或“不同”?\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n选项:相同,不同\\n答:", 44 | "“[INPUT_TEXT_1]”和“[INPUT_TEXT_2]”是同一个意思吗?\\n选项:是的,不是\\n答:", 45 | "“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n这两句是一样的意思吗?\\n选项:是的,不是\\n答:" 46 | ], 47 | "CHIP-CTC": [ 48 | "判断临床试验筛选标准的类型:\\n[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n答:", 49 | "确定试验筛选标准的类型:\\n[INPUT_TEXT]\\n类型选项:[LIST_LABELS]\\n答:", 50 | "[INPUT_TEXT]\\n这句话是什么临床试验筛选标准类型?\\n类型选项:[LIST_LABELS]\\n答:", 51 | "[INPUT_TEXT]\\n是什么临床试验筛选标准类型?\\n选项:[LIST_LABELS]\\n答:", 52 | "请问是什么类型?\\n[INPUT_TEXT]\\n临床试验筛选标准选项:[LIST_LABELS]\\n答:" 53 | ], 54 | "KUAKE-IR": [ 55 | "以下回答内容是否与这里的医疗搜索相关?\\n医疗搜索:[INPUT_TEXT_1]\\n回答内容:[INPUT_TEXT_2]\\n选项: [LIST_LABELS]\\n答:", 56 | "医疗搜索:[INPUT_TEXT_1]\\n以下回答内容是否能够回答搜索问题?\\n回答内容:[INPUT_TEXT_2]\\n选项: [LIST_LABELS]\\n答:", 57 | "医疗搜索:[INPUT_TEXT_1]\\n回答内容:[INPUT_TEXT_2]\\n上述搜索和回答是否相关?\\n选项: [LIST_LABELS]\\n答:" 58 | ], 59 | "KUAKE-QIC": [ 60 | "判断下面搜索词的意图:\\n[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n答:", 61 | "确定检索词的类型:\\n[INPUT_TEXT]\\n类型选项:[LIST_LABELS]\\n答:", 62 | "[INPUT_TEXT]\\n这个搜索是什么意图?\\n类型选项:[LIST_LABELS]\\n答:", 63 | "[INPUT_TEXT]\\n这个医疗搜索词是什么意图分类?\\n选项:[LIST_LABELS]\\n答:", 64 | "请问是什么意图类型?\\n[INPUT_TEXT]\\n搜索意图选项:[LIST_LABELS]\\n答:" 65 | ], 66 | "KUAKE-QQR": [ 67 | "判断两个查询所表述的主题的匹配程度:\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n选项:[LIST_LABELS]\\n答:", 68 | "我想知道下面两个搜索词的意思有多相同。\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n选项:[LIST_LABELS]\\n答:", 69 | "下面两个句子的语义关系是?\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n选项: [LIST_LABELS]\\n答:", 70 | "“[INPUT_TEXT_1]”和“[INPUT_TEXT_2]”表述的主题完全一致吗?\\n选项:[LIST_LABELS]\\n答:", 71 | "“[INPUT_TEXT_1]”和“[INPUT_TEXT_2]”的意思有多相似?\\n选项:[LIST_LABELS]\\n答:", 72 | 
"“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n这两句是一样的意思吗?\\n选项:[LIST_LABELS]\\n答:", 73 | "“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n这两句的语义关系是?\\n选项:[LIST_LABELS]\\n答:" 74 | ], 75 | "KUAKE-QTR": [ 76 | "以下两句话的意思相同的吗?\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n选项:[LIST_LABELS]\\n答:", 77 | "我想知道下面两句话的意思有多相似。\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”\\n选项:[LIST_LABELS]\\n答:", 78 | "下面的搜索词和页面标签的意思有多相同?\\n搜索词:[INPUT_TEXT_1]\\n页面标签:[INPUT_TEXT_2]\\n选项:[LIST_LABELS]\\n答:", 79 | "下面两个句子的语义相似程度是[LIST_LABELS]中的哪一种?\\n“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”\\n答:", 80 | "“[INPUT_TEXT_1]”和“[INPUT_TEXT_2]”是同一个意思吗?\\n选项:[LIST_LABELS]\\n答:", 81 | "“[INPUT_TEXT_1]”和“[INPUT_TEXT_2]”的意思有多相似?\\n选项:[LIST_LABELS]\\n答:", 82 | "“[INPUT_TEXT_1]”,“[INPUT_TEXT_2]”。\\n这两句话的意思的匹配程度如何?\\n选项:[LIST_LABELS]\\n答:", 83 | "搜索词:“[INPUT_TEXT_1]”。页面标题:“[INPUT_TEXT_2]”。这两句是一样的意思吗?选项:[LIST_LABELS]\\n答:" 84 | ], 85 | "CHIP-MDCFNPC": [ 86 | "医疗对话中临床发现实体的阴阳性判别:\\n[INPUT_TEXT]\\n临床发现实体:[LIST_MENTIONS]\\n阴阳性选项:[LIST_LABELS]\\n说明:临床发现是临床医学下,病人状态描述的概念集合\\n答:", 87 | "给出对话中临床发现实体的阴阳性判断:\\n[INPUT_TEXT]\\n临床发现实体:[LIST_MENTIONS]\\n阴阳性选项:[LIST_LABELS]\\n说明:临床发现是临床医学下,病人状态描述的概念集合\\n答:", 88 | "对下述对话中的临床发现标识阴阳性:\\n[INPUT_TEXT]\\n临床发现实体:[LIST_MENTIONS]\\n阴阳性选项:[LIST_LABELS]\\n说明:临床发现是临床医学下,病人状态描述的概念集合\\n答:", 89 | "[INPUT_TEXT]\\n问题:我们已经给出问诊对话中的临床发现,请问这些实体的阴阳性是?\\n临床发现实体:[LIST_MENTIONS]\\n阴阳性选项:[LIST_LABELS]\\n说明:临床发现是临床医学下,病人状态描述的概念集合\\n答:" 90 | ], 91 | "IMCS-V2-DAC": [ 92 | "判断下面问诊对话最后一句的意图:\\n[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n答:", 93 | "确定问诊对话当前句子的意图:\\n[INPUT_TEXT]\\n类型选项:[LIST_LABELS]\\n答:", 94 | "[INPUT_TEXT]\\n问诊对话的最后一话是什么意图?\\n类型选项:[LIST_LABELS]\\n答:", 95 | "[INPUT_TEXT]\\n最后的问诊句子是什么意图分类?\\n选项:[LIST_LABELS]\\n答:", 96 | "请问最后一句对话是什么意图类型?\\n[INPUT_TEXT]\\n意图选项:[LIST_LABELS]\\n答:" 97 | ], 98 | "IMCS-V2-NER": [ 99 | "找出下面问诊语句中的[LIST_LABELS]实体:\\n[INPUT_TEXT]\\n答:", 100 | "找出下面句子中的[LIST_LABELS]实体:\\n[INPUT_TEXT]\\n答:", 101 | "[INPUT_TEXT]\\n问题:上述对话中的[LIST_LABELS]实体是哪些?\\n答:", 102 | "[INPUT_TEXT]\\n上述问诊中的[LIST_LABELS]实体是什么?\\n答:", 103 | "下面对话中的[LIST_LABELS]实体有哪些?\\n[INPUT_TEXT]\\n答:", 104 | "问诊对话的实体抽取:[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n答:" 105 | ], 106 | "IMCS-V2-SR": [ 107 | "找出当前对话中的症状,并判断阴阳性:\\n[INPUT_TEXT]\\n症状阴阳性选项:[LIST_LABELS]\\n答:", 108 | "根据对话历史和当前对话,抽取症状实体,以及这些实体的阴阳性:\\n[INPUT_TEXT]\\n候选:[LIST_LABELS]\\n答:", 109 | "当前对话涉及哪些症状?这些症状的阴阳性如何?\\n[INPUT_TEXT]\\n选项:[LIST_LABELS]\\n答:", 110 | "[INPUT_TEXT]\\n根据上述对话,症状有哪些?这些症状的阴阳性是?\\n选项:[LIST_LABELS]\\n答:" 111 | ], 112 | "IMCS-V2-MRG": [ 113 | "根据下面的问诊对话自动生成对应的诊疗报告:\\n[INPUT_TEXT]\\n说明:诊疗报告分为主诉, 现病史, 辅助检查, 既往史, 诊断, 建议这六个章节。\\n答:", 114 | "帮助患者自动总结问诊的诊疗报告:\\n[INPUT_TEXT]\\n说明:诊疗报告分为主诉, 现病史, 辅助检查, 既往史, 诊断, 建议这六个章节。\\n答:", 115 | "总结下面问诊对话,并给出问诊报告\\n[INPUT_TEXT]\\n说明:诊疗报告分为主诉, 现病史, 辅助检查, 既往史, 诊断, 建议这六个章节。\\n答:", 116 | "[INPUT_TEXT]\\n根据上述对话,给出诊疗报告\\n说明:诊疗报告分为主诉, 现病史, 辅助检查, 既往史, 诊断, 建议这六个章节。\\n答:" 117 | ], 118 | "MedDG": [ 119 | "根据医生和患者交流的对话历史预测出医生的下一句回复:\\n[INPUT_TEXT]\\n答:", 120 | "自动生成问诊对话中的医生下一句回复:\\n[INPUT_TEXT]\\n答:", 121 | "根据下面的问诊对话历史,给出医生的下一句回复\\n[INPUT_TEXT]\\n答:", 122 | "[INPUT_TEXT]\\n根据上述对话历史,给出医生的下一句话\\n答:", 123 | "[INPUT_TEXT]\\n根据上述对话历史,作为医生应该如何回复?\\n答:" 124 | ] 125 | } -------------------------------------------------------------------------------- /src/data/结构化预测结果格式说明.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### 结构化预测结果格式说明 4 | 5 | 我们现在对评测参与者需要提交的结构化预测结果results.json文件进行详细的说明。results.json文件整体可以采用json.load()方法加载。其内部结构如下: 6 | ```bash 7 | 8 | { 9 | "task_name1": [ 10 | { 11 | "sample_id": str, 12 | "answer": answer_sample, 13 | } 14 | ], 15 | "task_name2": [ 
16 | { 17 | "sample_id": str, 18 | "answer": answer_sample, 19 | } 20 | ], 21 | "task_name3": [ 22 | { 23 | "sample_id": str, 24 | "answer": answer_sample, 25 | } 26 | ], 27 | } 28 | ``` 29 | 其中answer_sample的格式因任务不同而格式各异。`task_name1`, `task_name2`,... 的取值为['CMeEE-V2', 'CMeIE', 'CHIP-CDN', 'CHIP-CDEE', 'CHIP-STS', 'CHIP-CTC', 'CHIP-MDCFNPC', 'KUAKE-IR', 'KUAKE-QIC', 'KUAKE-QQR', 'KUAKE-QTR', 'MedDG', 'IMCS-V2-MRG', 'IMCS-V2-NER', 'IMCS-V2-DAC', 'IMCS-V2-SR']。 30 | 31 | 评测参与队伍可以参考[dev.json](./datasets/PromptCBLUE/toy_examples/dev.json)和 [dev_structured.json](./datasets/PromptCBLUE/toy_examples/dev_structured.json)文件来理解由LLM输出到评测规定的结构化格式的转化。 32 | 33 | 我们现在分各个任务说明`answer_sample`的格式。 34 | 35 | #### CMeEE-V2任务 36 | 37 | `answer_sample`为list,list中每个元素包含两个字段: entity 和 type。entity是文本中的医学实体mention,type为样本提示/指令中规定的医学实体类型名称。 38 | 39 | ```bash 40 | answer_sample = [ 41 | { 42 | "entity": str, 43 | "type": str 44 | } 45 | ] 46 | ``` 47 | 48 | #### CMeIE任务 49 | 50 | `answer_sample`为list,list中每个元素包含三个字段: subject是头实体提及,object是尾实体提及,predicate是样本提示/指令中规定的实体间关系类型名称。 51 | 52 | ```bash 53 | answer_sample = [ 54 | { 55 | "predicate": str, 56 | "subject": str, 57 | "object": str 58 | } 59 | ] 60 | ``` 61 | 62 | 63 | #### CHIP-CDEE任务 64 | 65 | `answer_sample`为list,list中每个元素包含四个字段: 医学临床事件的`主体词`字段,`发生状态`字段,`描述词`, `解剖部位`字段。`主体词`字段和`发生状态`字段都是字符串。`描述词`, `解剖部位`字段,都是非空字符串的列表。 66 | 67 | ```bash 68 | answer_sample = [ 69 | { 70 | "主体词": str, 71 | "发生状态": str, 72 | "描述词": [ 73 | str 74 | ], 75 | "解剖部位": [ 76 | str 77 | ] 78 | } 79 | ] 80 | ``` 81 | 82 | 83 | 84 | #### CHIP-CDN任务 85 | 86 | `answer_sample`为list,list中每个元素包含两个字段: entity为ICD标准词库中的词条,"type"字段取值固定为"normalization"。 87 | 88 | ```bash 89 | answer_sample = [ 90 | { 91 | "entity": str, 92 | "type": "normalization" 93 | } 94 | ] 95 | ``` 96 | 97 | 98 | 99 | #### 分类型任务 100 | 101 | 对CHIP-CTC任务,KUAKE-QIC,IMCS-V2-DAC任务,`answer_sample`为str,取值为: 样本提示/指令中规定的分类类型名称,或者是"非上述类型"。 102 | 103 | ```bash 104 | answer_sample = str 105 | ``` 106 | 107 | 108 | #### 文本对任务 109 | 110 | 对CHIP-STS任务,KUAKE-QTR,KUAKE-QQR,KUAKE-IR任务, `answer_sample`为str,取值为: 样本提示/指令中规定的类型标签名称。 111 | 112 | ```bash 113 | answer_sample = str 114 | ``` 115 | 116 | #### CHIP-MDCFNPC任务 117 | 118 | `answer_sample`为list,list中每个元素包含两个字段: entity为对话中的症状词,"attr"字段取值必须是: 样本提示/指令中规定的属性类型标签名称。 119 | 120 | ```bash 121 | answer_sample = [ 122 | { 123 | "entity": str, 124 | "attr": str 125 | } 126 | ] 127 | ``` 128 | 129 | 130 | #### IMCS-V2-NER任务 131 | 132 | `answer_sample`为list,list中每个元素包含两个字段: entity 和 type。entity是文本中的医学实体mention,type为样本提示/指令中规定的医学实体类型名称。 133 | 134 | ```bash 135 | answer_sample = [ 136 | { 137 | "entity": str, 138 | "type": str 139 | } 140 | ] 141 | ``` 142 | 143 | 144 | #### IMCS-V2-SR任务 145 | 146 | `answer_sample`为list,list中每个元素包含两个字段: entity为对话中的症状词,"attr"字段取值必须是: 样本提示/指令中规定的属性类型标签名称。 147 | 148 | ```bash 149 | answer_sample = [ 150 | { 151 | "entity": str, 152 | "attr": str 153 | } 154 | ] 155 | ``` 156 | 157 | 158 | #### IMCS-V2-MRG任务 159 | 160 | `answer_sample`为str,取值为: 模型生成的诊断报告。 161 | 162 | ```bash 163 | answer_sample = str 164 | ``` 165 | 166 | 167 | #### MedDG任务 168 | 169 | `answer_sample`为str,取值为: 模型生成的对话回复。 170 | 171 | ```bash 172 | answer_sample = str 173 | ``` 174 | 175 | 176 | -------------------------------------------------------------------------------- /src/download_checkpoints.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import snapshot_download 2 | 3 | repo_id = "THUDM/visualglm-6b" 4 | downloaded = snapshot_download( 5 | 
repo_id, 6 | cache_dir="./", 7 | ) -------------------------------------------------------------------------------- /src/evaluation/README.txt: -------------------------------------------------------------------------------- 1 | 2 | ## 验证:LLM回复转化为结构化格式的代码(参赛选手需要根据自己的LLM输出格式编写格式转化代码): 3 | python post_generate_process.py dev_predictions.json results.json 4 | 5 | ## 评分 6 | ./py_entrance.sh input_param.json eval_result.json 7 | cat eval_result.json 8 | 9 | 或者运行 10 | python evaluate.py input_param.json eval_result.json 11 | cat eval_result.json 12 | 13 | ## 参考 14 | 更多详细的数据说明和baseline方法实现,见https://github.com/michael-wzhu/PromptCBLUE -------------------------------------------------------------------------------- /src/evaluation/evaluators.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from sklearn.metrics import classification_report 4 | from transformers import BasicTokenizer 5 | from rouge_chinese import Rouge 6 | 7 | from text2dt_eval_func import text2dt_eval_single_tree 8 | 9 | basic_tokenizer = BasicTokenizer(tokenize_chinese_chars=True) 10 | 11 | 12 | def calc_info_extract_task_scores(list_structured_golden, 13 | list_structured_predict): 14 | 15 | assert len(list_structured_golden) == len(list_structured_predict) 16 | 17 | tp = 0 18 | fp = 0 19 | fn = 0 20 | for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): 21 | assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!" 22 | answer_golden = samp_golden["answer"] 23 | answer_predict = samp_predict["answer"] 24 | 25 | assert isinstance(answer_golden, list) 26 | assert isinstance(answer_predict, list), "sample format is wrong!" 27 | 28 | set_golden = set() 29 | for inst in answer_golden: 30 | assert isinstance(inst, dict) 31 | keys = sorted(list(inst.keys())) 32 | inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys ]) 33 | # inst = list(inst.items()) 34 | # inst.sort() 35 | # inst = tuple(inst) 36 | 37 | set_golden.add(inst) 38 | 39 | set_predict = set() 40 | for inst in answer_predict: 41 | assert isinstance(inst, dict) 42 | keys = sorted(list(inst.keys())) 43 | # inst = tuple([inst[w] for w in keys]) 44 | inst = tuple([json.dumps(inst[w], ensure_ascii=False) for w in keys]) 45 | 46 | # inst = list(inst.items()) 47 | # inst.sort() 48 | # inst = tuple(inst) 49 | 50 | set_predict.add(inst) 51 | 52 | # print("set_predict: ", set_predict) 53 | # print("set_golden: ", set_golden) 54 | 55 | tp += len(set_golden.intersection(set_predict)) 56 | fp += len(set_predict.difference(set_golden)) 57 | fn += len(set_golden.difference(set_predict)) 58 | 59 | if tp: 60 | precision = tp / (tp + fp) 61 | recall = tp / (tp + fn) 62 | f1 = 2 * precision * recall / (precision + recall) 63 | 64 | else: 65 | precision, recall, f1 = 0, 0, 0 66 | 67 | return precision, recall, f1 68 | 69 | 70 | def calc_cls_task_scores(list_structured_golden, 71 | list_structured_predict, 72 | list_labels=None, 73 | return_macro=False, 74 | ): 75 | # types = list_labels 76 | # scores = {c: {"tp": 0, "fp": 0, "fn": 0, "tn": 0} for c in list_labels + ["ALL"]} 77 | 78 | predictions = [] 79 | ground_truths = [] 80 | 81 | # Count GT relations and Predicted relations 82 | assert len(list_structured_golden) == len(list_structured_predict) 83 | n_sents = len(list_structured_golden) 84 | 85 | # Count TP, FP and FN per type 86 | for pred_samp, gt_samp in zip(list_structured_predict, list_structured_golden): 87 | assert 
pred_samp["sample_id"] == gt_samp["sample_id"], "sample ordering is wrong!" 88 | 89 | pred_label = pred_samp["answer"] 90 | gt_label = gt_samp["answer"] 91 | assert gt_label != "" 92 | if pred_label == "": 93 | pred_label = list_labels[0] 94 | 95 | predictions.append(pred_label) 96 | ground_truths.append(gt_label) 97 | 98 | # metric 99 | t0 = time.time() 100 | cls_report = classification_report( 101 | ground_truths, predictions, 102 | output_dict=True, 103 | zero_division=0, 104 | ) 105 | # print(cls_report) 106 | 107 | t1 = time.time() 108 | # print("calculation metrics: ", t1 - t0) 109 | 110 | if return_macro: 111 | return cls_report["macro avg"]["precision"], \ 112 | cls_report["macro avg"]["recall"], \ 113 | cls_report["macro avg"]["f1-score"] 114 | else: 115 | return cls_report["weighted avg"]["precision"], \ 116 | cls_report["weighted avg"]["recall"], \ 117 | cls_report["weighted avg"]["f1-score"] 118 | 119 | 120 | def calc_nlg_task_scores(list_structured_golden, list_structured_predict): 121 | 122 | 123 | assert len(list_structured_golden) == len(list_structured_predict) 124 | 125 | scores = [] 126 | predictions = [] 127 | references = [] 128 | for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): 129 | # print("samp_golden: ", samp_golden) 130 | # print("samp_predict: ", samp_predict) 131 | 132 | assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!" 133 | answer_golden = samp_golden["answer"] 134 | answer_predict = samp_predict["answer"] 135 | 136 | assert isinstance(answer_golden, str) 137 | assert isinstance(answer_predict, str), "sample format is wrong!" 138 | 139 | # basic tokenizer: 拆分中文字,保留英文单词 140 | answer_predict = basic_tokenizer.tokenize(answer_predict) 141 | answer_golden = basic_tokenizer.tokenize(answer_golden) 142 | answer_predict = " ".join(answer_predict).strip() 143 | answer_golden = " ".join(answer_golden).strip() 144 | if answer_golden.strip() == "": 145 | answer_golden = "无 。" 146 | if answer_predict.strip() == "": 147 | answer_predict = "无 。" 148 | # print("answer_predict: ", answer_predict) 149 | # print("answer_golden: ", answer_golden) 150 | 151 | predictions.append(answer_predict) 152 | references.append(answer_golden) 153 | 154 | rouge = Rouge() 155 | scores = rouge.get_scores(predictions, references, avg=True) 156 | 157 | rouge1 = scores["rouge-1"]["f"] 158 | rouge2 = scores["rouge-2"]["f"] 159 | rougeL = scores["rouge-l"]["f"] 160 | 161 | return rouge1, rouge2, rougeL 162 | 163 | 164 | def calc_nlg_task_scores_by_sessions(list_structured_golden, list_structured_predict): 165 | 166 | 167 | assert len(list_structured_golden) == len(list_structured_predict) 168 | 169 | scores = [] 170 | predictions = [] 171 | references = [] 172 | for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): 173 | # print("samp_golden: ", samp_golden) 174 | # print("samp_predict: ", samp_predict) 175 | 176 | assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!" 
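        # 这里的 answer 是按章节组织的字典(例如 IMCS-V2-MRG 诊疗报告的主诉、现病史等章节),下面逐章节与标准答案对齐后再统一计算 Rouge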
177 | answer_golden = samp_golden["answer"] 178 | answer_predict = samp_predict["answer"] 179 | 180 | # if set(answer_golden.keys()) != set(answer_predict.keys()) 181 | 182 | for key in answer_golden.keys(): 183 | pred = answer_predict.get(key, "").strip() 184 | gt = answer_golden[key].strip() 185 | 186 | # basic tokenizer: 拆分中文字,保留英文单词 187 | pred = basic_tokenizer.tokenize(pred) 188 | gt = basic_tokenizer.tokenize(gt) 189 | pred = " ".join(pred).strip() 190 | gt = " ".join(gt).strip() 191 | if gt.strip() == "": 192 | gt = "无 。" 193 | if pred.strip() == "": 194 | pred = "无 。" 195 | 196 | # if gt != pred: 197 | # print(gt) 198 | # print(pred) 199 | 200 | predictions.append( 201 | pred 202 | ) 203 | references.append( 204 | gt 205 | ) 206 | 207 | rouge = Rouge() 208 | scores = rouge.get_scores(predictions, references, avg=True) 209 | rouge1 = scores["rouge-1"]["f"] 210 | rouge2 = scores["rouge-2"]["f"] 211 | rougeL = scores["rouge-l"]["f"] 212 | 213 | return rouge1, rouge2, rougeL 214 | 215 | 216 | def calc_text2dt_task_scores(list_structured_golden, 217 | list_structured_predict,): 218 | 219 | assert len(list_structured_golden) == len(list_structured_predict) 220 | 221 | gold_tree_num, correct_tree_num = 0.000001, 0.000001 222 | gold_triplet_num, predict_triplet_num, correct_triplet_num = 0.000001, 0.000001, 0.000001 223 | gold_path_num, predict_path_num, correct_path_num = 0.000001, 0.000001, 0.000001 224 | gold_node_num, predict_node_num, correct_node_num = 0.000001, 0.000001, 0.000001 225 | 226 | edit_dis = 0 227 | max_edit_dis = 0 228 | 229 | for samp_golden, samp_predict in zip(list_structured_golden, list_structured_predict): 230 | assert samp_golden["sample_id"] == samp_predict["sample_id"], "sample ordering is wrong!" 231 | tree_golden = samp_golden["answer"] 232 | tree_predict = samp_predict["answer"] 233 | 234 | assert isinstance(tree_golden, list) 235 | assert isinstance(tree_predict, list), "sample format is wrong!" 
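        # text2dt_eval_single_tree 返回 12 元组,依次为:
        # (tree_num, correct_tree_num, correct_triplet_num, predict_triplet_num, gold_triplet_num,
        #  correct_path_num, predict_path_num, gold_path_num, edit_dis,
        #  correct_node_num, predict_node_num, gold_node_num),下面按位置逐项累加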
236 | 237 | tmp = text2dt_eval_single_tree(tree_predict, tree_golden) 238 | gold_tree_num += tmp[0] 239 | correct_tree_num += tmp[1] 240 | correct_triplet_num += tmp[2] 241 | predict_triplet_num += tmp[3] 242 | gold_triplet_num += tmp[4] 243 | correct_path_num += tmp[5] 244 | predict_path_num += tmp[6] 245 | gold_path_num += tmp[7] 246 | edit_dis += tmp[8] 247 | 248 | # 计算最大编辑数 249 | max_edit_dis += (tmp[3] + tmp[10] * 2) + (tmp[4] + tmp[11] * 2) 250 | 251 | correct_node_num += tmp[9] 252 | predict_node_num += tmp[10] 253 | gold_node_num += tmp[11] 254 | 255 | tree_acc = correct_tree_num / gold_tree_num 256 | triplet_f1 = 2 * (correct_triplet_num / predict_triplet_num) * (correct_triplet_num / gold_triplet_num) / ( 257 | correct_triplet_num / predict_triplet_num + correct_triplet_num / gold_triplet_num) 258 | path_f1 = 2 * (correct_path_num / predict_path_num) * (correct_path_num / gold_path_num) / ( 259 | correct_path_num / predict_path_num + correct_path_num / gold_path_num) 260 | tree_lenv_radio = 1 - edit_dis / max_edit_dis 261 | node_f1 = 2 * (correct_node_num / predict_node_num) * (correct_node_num / gold_node_num) / ( 262 | correct_node_num / predict_node_num + correct_node_num / gold_node_num) 263 | 264 | return tree_lenv_radio, node_f1, path_f1 265 | 266 | -------------------------------------------------------------------------------- /src/evaluation/input_param.json: -------------------------------------------------------------------------------- 1 | { 2 | "fileData":{ 3 | "evaluatorDir":"", 4 | "evaluatorPath":"", 5 | "standardFileDir":"", 6 | "standardFilePath":"dev_structured.json", 7 | "userFileDir":"", 8 | "userFilePath":"results.json" 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /src/evaluation/text2dt_eval_func.py: -------------------------------------------------------------------------------- 1 | 2 | # 将符合诊疗决策树约束的节点前序序列转化为代表诊疗决策树结构的节点矩阵,matrix[i][j]='F'/'L'/'R'表示第j个节点是第i个节点的父/左子/右子节点 3 | import copy 4 | 5 | 6 | def nodematrix(tree): 7 | nodelist=[] 8 | for i in range(len(tree)): 9 | nodelist.append(tree[i]["role"]) 10 | node_matrix = [[0 for i in range(len(nodelist))] for j in range(len(nodelist))] 11 | 12 | # if len(tree) == 0: 13 | # return (node_matrix) 14 | 15 | count = 0 16 | while (nodelist[0] != 'D'): 17 | for i in range(len(nodelist)): 18 | if nodelist[i] == 'C': 19 | flag, leaf1, leaf2 = 0, 0, 0 20 | for j in range(i+1,len(nodelist)): 21 | if nodelist[j]=='D' and flag==0: 22 | flag = 1 23 | leaf1 = j 24 | elif nodelist[j]=='X' : 25 | continue 26 | elif nodelist[j]=='D' and flag==1: 27 | #print(i) 28 | leaf2 = j 29 | nodelist[i]='D' 30 | node_matrix[leaf1][i] = 'F' 31 | node_matrix[leaf2][i] = 'F' 32 | node_matrix[i][leaf1] = 'L' 33 | node_matrix[i][leaf2] = 'R' 34 | for k in range(i+1, leaf2+1): 35 | nodelist[k]='X' 36 | flag = 2 37 | break 38 | elif nodelist[j] == 'C': 39 | break 40 | if flag == 2: 41 | break 42 | 43 | count += 1 44 | if count > 100: 45 | break 46 | 47 | return(node_matrix) 48 | 49 | # 计算两个节点的距离 50 | def node_dis(node1,node2): 51 | if node2 is None : 52 | #node2 = {"role": "", "triples": [], "logical_rel": ""} 53 | node2 = {"role": "", "triples": [], "logical_rel": "null"} 54 | dis = 0 55 | if node1["role"] != node2["role"]: 56 | dis += 1 57 | #print(dis) 58 | if node1["logical_rel"] != node2["logical_rel"]: 59 | dis += 1 60 | dis += len(list((set(node1["triples"])|set(node2["triples"]))-(set(node1["triples"])&set(node2["triples"])))) 61 | return(dis) 62 | 63 | # 判断两条路径是否相同 64 | def 
is_path_equal(path1,path2): 65 | if (len(path1)!=len(path2)): 66 | return False 67 | for i in range(len(path1)): 68 | if isinstance(path1[i],dict) and isinstance(path2[i],dict): 69 | if path1[i]['role'] == path2[i]['role'] and path1[i]['logical_rel'] == path2[i]['logical_rel'] and \ 70 | set(path1[i]['triples']) == set(path2[i]['triples']): 71 | continue 72 | else: 73 | return False 74 | elif path1[i] != path2[i]: 75 | return False 76 | return True 77 | 78 | # 判断两棵树是否相同 79 | def is_tree_equal(predict_tree,gold_tree): 80 | if len(predict_tree) != len(gold_tree): 81 | return 0 82 | else: 83 | for i in range(len(predict_tree)): 84 | if predict_tree[i]['role'] == gold_tree[i]['role'] and \ 85 | predict_tree[i]['logical_rel'] == gold_tree[i]['logical_rel'] and \ 86 | set(predict_tree[i]['triples']) == set(gold_tree[i]['triples']): 87 | continue 88 | else: 89 | return 0 90 | return 1 91 | 92 | # 计算模型预测的诊疗决策树和ground turth的距离,距离越小表示两树越相似,为计算编辑比率做准备 93 | def edit_distance(predict_tree, gold_tree, predict_matrix, gold_matrix): 94 | dis = 0 95 | stack1 = [0] 96 | stack2 = [0] 97 | 98 | try: 99 | while stack1: 100 | s1=stack1.pop() 101 | s2=stack2.pop() 102 | if ('L' not in predict_matrix[s1] and 'R' not in predict_matrix[s1]) \ 103 | and ('L' in gold_matrix[s2] or 'R' in gold_matrix[s2]): 104 | dis += node_dis(predict_tree[s1], gold_tree[s2]) 105 | stack_tmp=[] 106 | stack_tmp.append(gold_matrix[s2].index('R')) 107 | stack_tmp.append(gold_matrix[s2].index('L')) 108 | while stack_tmp: 109 | s_tmp=stack_tmp.pop() 110 | dis += node_dis(gold_tree[s_tmp],None) 111 | if ('L' in gold_matrix[s_tmp] and 'R' in gold_matrix[s_tmp]): 112 | stack_tmp.append(gold_matrix[s_tmp].index('R')) 113 | stack_tmp.append(gold_matrix[s_tmp].index('L')) 114 | elif ('L' in predict_matrix[s1] and 'R' in predict_matrix[s1]) \ 115 | and ('L' not in gold_matrix[s2] or 'R' not in gold_matrix[s2]): 116 | dis += node_dis(predict_tree[s1], gold_tree[s2]) 117 | stack_tmp=[] 118 | stack_tmp.append(predict_matrix[s1].index('R')) 119 | stack_tmp.append(predict_matrix[s1].index('L')) 120 | while stack_tmp: 121 | s_tmp=stack_tmp.pop() 122 | dis += node_dis(predict_tree[s_tmp], None) 123 | if ('L' in predict_matrix[s_tmp] and 'R' in predict_matrix[s_tmp]): 124 | stack_tmp.append(predict_matrix[s_tmp].index('R')) 125 | stack_tmp.append(predict_matrix[s_tmp].index('L')) 126 | elif ('L' not in predict_matrix[s1] and 'R' not in predict_matrix[s1]) and \ 127 | ('L' not in gold_matrix[s2] and 'R' not in gold_matrix[s2]): 128 | dis += node_dis(predict_tree[s1], gold_tree[s2]) 129 | else: 130 | stack1.append(predict_matrix[s1].index('R')) 131 | stack1.append(predict_matrix[s1].index('L')) 132 | stack2.append(gold_matrix[s2].index('R')) 133 | stack2.append(gold_matrix[s2].index('L')) 134 | dis += node_dis(predict_tree[s1], gold_tree[s2]) 135 | 136 | except Exception as e: 137 | print("calculating edit dist wrong!") 138 | print(e) 139 | 140 | return dis 141 | 142 | # 计算决策路径抽取的TP,TP+FP,TP+FN 143 | def decision_path(predict_tree, gold_tree, predict_matrix, gold_matrix): 144 | leaf1, leaf2, paths1, paths2 = [], [], [], [] 145 | 146 | try: 147 | for i in range(len(predict_matrix)): 148 | if ('L' not in predict_matrix[i] and 'R' not in predict_matrix[i]): 149 | leaf1.append(i) 150 | for node in leaf1: 151 | path=[predict_tree[node]] 152 | while node !=0: 153 | #print(predict_matrix) 154 | #print(node) 155 | #print(predict_matrix[node]) 156 | path.append(predict_matrix[predict_matrix[node].index('F')][node]) 157 | 
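                # 再把父节点本身加入路径('F' 标记指向父节点),随后继续向根节点回溯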
path.append(predict_tree[predict_matrix[node].index('F')]) 158 | node =predict_matrix[node].index('F') 159 | paths1.append(path) 160 | for i in range(len(gold_matrix)): 161 | if ('L' not in gold_matrix[i] and 'R' not in gold_matrix[i]): 162 | leaf2.append(i) 163 | for node in leaf2: 164 | path=[gold_tree[node]] 165 | while node != 0: 166 | path.append(gold_matrix[gold_matrix[node].index('F')][node]) 167 | path.append(gold_tree[gold_matrix[node].index('F')]) 168 | node =gold_matrix[node].index('F') 169 | paths2.append(path) 170 | res = 0 171 | for path1 in paths1: 172 | for path2 in paths2: 173 | if is_path_equal(path1, path2): 174 | res += 1 175 | break 176 | except Exception as e: 177 | print("calculating decision path wrong!") 178 | print(e) 179 | res = 0 180 | 181 | return res,len(paths1),len(paths2) 182 | 183 | 184 | # 计算三元组抽取的TP,TP+FP,TP+FN 185 | def triplet_extraction(predict_tree, gold_tree): 186 | predict_triplet, gold_triplet = [], [] 187 | for i in range(len(predict_tree)): 188 | for triplet in predict_tree[i]["triples"]: 189 | predict_triplet.append(triplet) 190 | for i in range(len(gold_tree)): 191 | for triplet in gold_tree[i]["triples"]: 192 | gold_triplet.append(triplet) 193 | predict_triplet_num = len(list(set(predict_triplet))) 194 | gold_triplet_num = len(list(set(gold_triplet))) 195 | correct_triplet_num =len(list(set(gold_triplet)&set(predict_triplet))) 196 | return [correct_triplet_num, predict_triplet_num, gold_triplet_num] 197 | 198 | # 计算节点抽取的TP,TP+FP,TP+FN 199 | def node_extraction(predict_tree, gold_tree): 200 | predict_node, gold_node = [], [] 201 | for i in range(len(predict_tree)): 202 | if len(predict_tree[i]['triples'])>0: 203 | predict_node.append(predict_tree[i]) 204 | for i in range(len(gold_tree)): 205 | if len(gold_tree[i]['triples']) > 0: 206 | gold_node.append(gold_tree[i]) 207 | 208 | predict_triplet_num = len(predict_node) 209 | gold_triplet_num = len(gold_node) 210 | correct_triplet_num = 0 211 | for node1 in predict_node: 212 | for node2 in gold_node: 213 | if len(node1['triples'])>0 and node1['role'] == node2['role'] and node1['logical_rel'] == node2['logical_rel'] and set(node1['triples']) == set(node2['triples']): 214 | correct_triplet_num +=1 215 | return [correct_triplet_num, predict_triplet_num, gold_triplet_num] 216 | 217 | #评测函数,共计算5个指标: 三元组抽取的F1;节点抽取的F1;决策树的Acc;决策路径的F1; 树的编辑距离 218 | def text2dt_eval_single_tree(predict_tree, gold_tree): 219 | # 将符合诊疗决策树的节点前序序列转化为代表诊疗决策树结构的节点矩阵,matrix[i][j]='F'/'L'/'R'表示第j个节点是第i个节点的父/左子/右子节点 220 | for node in predict_tree: 221 | for i in range(len(node['triples'])): 222 | print(node['triples'][i]) 223 | assert len(node['triples'][i]) == 3, "the triple format is wrong" 224 | node['triples'][i]=(node['triples'][i][0].lower(), node['triples'][i][1].lower(), node['triples'][i][2].lower()) 225 | for node in gold_tree: 226 | for i in range(len(node['triples'])): 227 | assert len(node['triples'][i]) == 3, "the triple format is wrong" 228 | node['triples'][i]=(node['triples'][i][0].lower(), node['triples'][i][1].lower(), node['triples'][i][2].lower()) 229 | 230 | # print("step1: ") 231 | predict_matrix = nodematrix(predict_tree) 232 | gold_matrix = nodematrix(gold_tree) 233 | 234 | # 用于计算生成树的Acc 235 | tree_num = (0 if predict_tree == [] else 1) 236 | correct_tree_num = is_tree_equal(predict_tree,gold_tree) 237 | 238 | # 用于计算triplet抽取的F1 239 | correct_triplet_num, predict_triplet_num, gold_triplet_num = triplet_extraction(predict_tree, gold_tree) 240 | 241 | # 用于计算决策路径的F1 242 | # print("step2: ") 243 | 
correct_path_num, predict_path_num, gold_path_num = decision_path( 244 | copy.deepcopy(predict_tree), 245 | copy.deepcopy(gold_tree), 246 | copy.deepcopy(predict_matrix), 247 | copy.deepcopy(gold_matrix) 248 | ) 249 | # print("correct_path_num: ", correct_path_num) 250 | 251 | # 用于计算树的编辑距离 252 | edit_dis = edit_distance(predict_tree, gold_tree, predict_matrix, gold_matrix) 253 | 254 | correct_node_num, predict_node_num, gold_node_num = node_extraction(predict_tree, gold_tree) 255 | 256 | return tree_num,correct_tree_num, correct_triplet_num, predict_triplet_num, gold_triplet_num, correct_path_num, predict_path_num, gold_path_num, edit_dis, correct_node_num, predict_node_num, gold_node_num 257 | 258 | 259 | 260 | 261 | 262 | -------------------------------------------------------------------------------- /src/ft_chatglm_lora/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | ptuning_checkpoint: str = field( 15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} 16 | ) 17 | config_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 26 | ) 27 | use_fast_tokenizer: bool = field( 28 | default=True, 29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 30 | ) 31 | model_revision: str = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 34 | ) 35 | use_auth_token: bool = field( 36 | default=False, 37 | metadata={ 38 | "help": ( 39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 40 | "with private models)." 41 | ) 42 | }, 43 | ) 44 | resize_position_embeddings: Optional[bool] = field( 45 | default=None, 46 | metadata={ 47 | "help": ( 48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 49 | "the model's position embeddings." 50 | ) 51 | }, 52 | ) 53 | quantization_bit: Optional[int] = field( 54 | default=None 55 | ) 56 | pre_seq_len: Optional[int] = field( 57 | default=None 58 | ) 59 | prefix_projection: bool = field( 60 | default=False 61 | ) 62 | 63 | trainable: Optional[str] = field(default="q_proj,v_proj") 64 | lora_rank: Optional[int] = field(default=8) 65 | lora_dropout: Optional[float] = field(default=0.1) 66 | lora_alpha: Optional[float] = field(default=32.) 67 | modules_to_save: Optional[str] = field(default='embed_tokens,lm_head') 68 | debug_mode: Optional[bool] = field(default=False) 69 | peft_path: Optional[str] = field(default=None) 70 | 71 | 72 | @dataclass 73 | class DataTrainingArguments: 74 | """ 75 | Arguments pertaining to what data we are going to input our model for training and eval. 
76 | """ 77 | 78 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 79 | 80 | dataset_name: Optional[str] = field( 81 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 82 | ) 83 | dataset_config_name: Optional[str] = field( 84 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 85 | ) 86 | prompt_column: Optional[str] = field( 87 | default=None, 88 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 89 | ) 90 | response_column: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 93 | ) 94 | history_column: Optional[str] = field( 95 | default=None, 96 | metadata={"help": "The name of the column in the datasets containing the history of chat."}, 97 | ) 98 | train_file: Optional[str] = field( 99 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 100 | ) 101 | validation_file: Optional[str] = field( 102 | default=None, 103 | metadata={ 104 | "help": ( 105 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 106 | ) 107 | }, 108 | ) 109 | test_file: Optional[str] = field( 110 | default=None, 111 | metadata={ 112 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 113 | }, 114 | ) 115 | overwrite_cache: bool = field( 116 | default=True, metadata={"help": "Overwrite the cached training and evaluation sets"} 117 | ) 118 | preprocessing_num_workers: Optional[int] = field( 119 | default=None, 120 | metadata={"help": "The number of processes to use for the preprocessing."}, 121 | ) 122 | max_source_length: Optional[int] = field( 123 | default=1024, 124 | metadata={ 125 | "help": ( 126 | "The maximum total input sequence length after tokenization. Sequences longer " 127 | "than this will be truncated, sequences shorter will be padded." 128 | ) 129 | }, 130 | ) 131 | max_target_length: Optional[int] = field( 132 | default=128, 133 | metadata={ 134 | "help": ( 135 | "The maximum total sequence length for target text after tokenization. Sequences longer " 136 | "than this will be truncated, sequences shorter will be padded." 137 | ) 138 | }, 139 | ) 140 | val_max_target_length: Optional[int] = field( 141 | default=None, 142 | metadata={ 143 | "help": ( 144 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 145 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 146 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 147 | "during ``evaluate`` and ``predict``." 148 | ) 149 | }, 150 | ) 151 | pad_to_max_length: bool = field( 152 | default=False, 153 | metadata={ 154 | "help": ( 155 | "Whether to pad all samples to model maximum sentence length. " 156 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 157 | "efficient on GPU but very bad for TPU." 158 | ) 159 | }, 160 | ) 161 | max_train_samples: Optional[int] = field( 162 | default=None, 163 | metadata={ 164 | "help": ( 165 | "For debugging purposes or quicker training, truncate the number of training examples to this " 166 | "value if set." 
167 | ) 168 | }, 169 | ) 170 | max_eval_samples: Optional[int] = field( 171 | default=None, 172 | metadata={ 173 | "help": ( 174 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 175 | "value if set." 176 | ) 177 | }, 178 | ) 179 | max_predict_samples: Optional[int] = field( 180 | default=None, 181 | metadata={ 182 | "help": ( 183 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 184 | "value if set." 185 | ) 186 | }, 187 | ) 188 | num_beams: Optional[int] = field( 189 | default=None, 190 | metadata={ 191 | "help": ( 192 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 193 | "which is used during ``evaluate`` and ``predict``." 194 | ) 195 | }, 196 | ) 197 | ignore_pad_token_for_loss: bool = field( 198 | default=True, 199 | metadata={ 200 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 201 | }, 202 | ) 203 | source_prefix: Optional[str] = field( 204 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 205 | ) 206 | 207 | forced_bos_token: Optional[str] = field( 208 | default=None, 209 | metadata={ 210 | "help": ( 211 | "The token to force as the first generated token after the decoder_start_token_id." 212 | "Useful for multilingual models like mBART where the first generated token" 213 | "needs to be the target language token (Usually it is the target language token)" 214 | ) 215 | }, 216 | ) 217 | 218 | 219 | 220 | def __post_init__(self): 221 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: 222 | raise ValueError("Need either a dataset name or a training/validation/test file.") 223 | else: 224 | if self.train_file is not None: 225 | extension = self.train_file.split(".")[-1] 226 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 227 | if self.validation_file is not None: 228 | extension = self.validation_file.split(".")[-1] 229 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
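        # These dataclasses are typically consumed in main.py through transformers.HfArgumentParser,
        # together with Seq2SeqTrainingArguments, along the lines of:
        #   parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
        #   model_args, data_args, training_args = parser.parse_args_into_dataclasses()
        # (assumed usage sketch; train.sh / evaluate.sh below show the actual command-line flags)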
230 | if self.val_max_target_length is None: 231 | self.val_max_target_length = self.max_target_length 232 | 233 | -------------------------------------------------------------------------------- /src/ft_chatglm_lora/evaluate.sh: -------------------------------------------------------------------------------- 1 | lora_rank=8 2 | lora_trainable="query_key_value,dense,dense_h_to_4h,dense_4h_to_h" 3 | modules_to_save="null" 4 | lora_dropout=0.1 5 | LR=2e-4 6 | model_name_or_path="./models--THUDM--chatglm-6b/snapshots/a8ede826cf1b62bd3c78bdfb3625c7c5d2048fbd" # LLM底座模型路径,或者是huggingface hub上的模型名称 7 | your_data_path="./datasets/PromptCBLUE/toy_examples" # 填入数据集所在的文件夹路径 8 | CHECKPOINT="./experiments/outputs/PromptCBLUE-chatglm-6b-lora-2e-4" # 填入用来存储模型的文件夹路径 9 | 10 | STEP=10 # 用来评估的模型checkpoint是训练了多少步 11 | 12 | CUDA_VISIBLE_DEVICES=3 python src/ft_chatglm_lora/main.py \ 13 | --do_predict \ 14 | --validation_file $your_data_path/dev.json \ 15 | --test_file $your_data_path/test.json \ 16 | --cache_dir $your_data_path \ 17 | --overwrite_cache \ 18 | --prompt_column input \ 19 | --response_column target \ 20 | --model_name_or_path $model_name_or_path \ 21 | --peft_path $CHECKPOINT/checkpoint-$STEP \ 22 | --output_dir $CHECKPOINT/checkpoint-$STEP \ 23 | --overwrite_output_dir \ 24 | --max_source_length 828 \ 25 | --max_target_length 196 \ 26 | --per_device_eval_batch_size 8 \ 27 | --predict_with_generate 28 | 29 | # --do_eval \ -------------------------------------------------------------------------------- /src/ft_chatglm_lora/train.sh: -------------------------------------------------------------------------------- 1 | lora_rank=8 2 | lora_trainable="query_key_value,dense,dense_h_to_4h,dense_4h_to_h" 3 | modules_to_save="null" 4 | lora_dropout=0.1 5 | LR=2e-4 6 | model_name_or_path="./models--THUDM--chatglm-6b/snapshots/a8ede826cf1b62bd3c78bdfb3625c7c5d2048fbd" # LLM底座模型路径,或者是huggingface hub上的模型名称 7 | your_data_path="./datasets/PromptCBLUE/toy_examples" # 填入数据集所在的文件夹路径 8 | your_checkpopint_path="./experiments/outputs/" # 填入用来存储模型的路径 9 | 10 | peft_path="" # 如果之前训练过,且存储了peft权重,则设置为peft权重的文件夹路径 11 | 12 | CUDA_VISIBLE_DEVICES=1 python src/ft_chatglm_lora/main.py \ 13 | --do_train \ 14 | --train_file $your_data_path/train.json \ 15 | --validation_file $your_data_path/dev.json \ 16 | --cache_dir $your_data_path \ 17 | --prompt_column input \ 18 | --response_column target \ 19 | --overwrite_cache \ 20 | --model_name_or_path $model_name_or_path \ 21 | --output_dir $your_checkpopint_path/PromptCBLUE-chatglm-6b-lora-$LR \ 22 | --overwrite_output_dir \ 23 | --max_source_length 828 \ 24 | --max_target_length 196 \ 25 | --per_device_train_batch_size 8 \ 26 | --per_device_eval_batch_size 4 \ 27 | --gradient_accumulation_steps 2 \ 28 | --max_steps 10000 \ 29 | --logging_steps 10 \ 30 | --save_steps 10 \ 31 | --learning_rate $LR \ 32 | --lora_rank ${lora_rank} \ 33 | --trainable ${lora_trainable} \ 34 | --modules_to_save ${modules_to_save} \ 35 | --lora_dropout ${lora_dropout} 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/ft_chatglm_lora/trainer_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import Dataset 20 | 21 | from transformers.deepspeed import is_deepspeed_zero3_enabled 22 | 23 | from transformers.trainer_utils import PredictionOutput 24 | from transformers.utils import logging 25 | 26 | 27 | from .trainer import Trainer 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | class Seq2SeqTrainer(Trainer): 33 | def evaluate( 34 | self, 35 | eval_dataset: Optional[Dataset] = None, 36 | ignore_keys: Optional[List[str]] = None, 37 | metric_key_prefix: str = "eval", 38 | **gen_kwargs 39 | ) -> Dict[str, float]: 40 | """ 41 | Run evaluation and returns metrics. 42 | 43 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 44 | (pass it to the init `compute_metrics` argument). 45 | 46 | You can also subclass and override this method to inject custom behavior. 47 | 48 | Args: 49 | eval_dataset (`Dataset`, *optional*): 50 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns 51 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 52 | method. 53 | ignore_keys (`List[str]`, *optional*): 54 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 55 | gathering predictions. 56 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 57 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 58 | "eval_bleu" if the prefix is `"eval"` (default) 59 | max_length (`int`, *optional*): 60 | The maximum target length to use when predicting with the generate method. 61 | num_beams (`int`, *optional*): 62 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 63 | beam search. 64 | gen_kwargs: 65 | Additional `generate` specific kwargs. 66 | 67 | Returns: 68 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 69 | dictionary also contains the epoch number which comes from the training state. 70 | """ 71 | 72 | gen_kwargs = gen_kwargs.copy() 73 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 74 | gen_kwargs["max_length"] = self.args.generation_max_length 75 | gen_kwargs["num_beams"] = ( 76 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 77 | ) 78 | self._gen_kwargs = gen_kwargs 79 | 80 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 81 | 82 | def predict( 83 | self, 84 | test_dataset: Dataset, 85 | ignore_keys: Optional[List[str]] = None, 86 | metric_key_prefix: str = "test", 87 | **gen_kwargs 88 | ) -> PredictionOutput: 89 | """ 90 | Run prediction and returns predictions and potential metrics. 91 | 92 | Depending on the dataset and your use case, your test dataset may contain labels. 
In that case, this method 93 | will also return metrics, like in `evaluate()`. 94 | 95 | Args: 96 | test_dataset (`Dataset`): 97 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the 98 | `model.forward()` method are automatically removed. Has to implement the method `__len__` 99 | ignore_keys (`List[str]`, *optional*): 100 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 101 | gathering predictions. 102 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 103 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 104 | "eval_bleu" if the prefix is `"eval"` (default) 105 | max_length (`int`, *optional*): 106 | The maximum target length to use when predicting with the generate method. 107 | num_beams (`int`, *optional*): 108 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 109 | beam search. 110 | gen_kwargs: 111 | Additional `generate` specific kwargs. 112 | 113 | 114 | 115 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic 116 | padding in a token classification task) the predictions will be padded (on the right) to allow for 117 | concatenation into one array. The padding index is -100. 118 | 119 | 120 | 121 | Returns: *NamedTuple* A namedtuple with the following keys: 122 | 123 | - predictions (`np.ndarray`): The predictions on `test_dataset`. 124 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). 125 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained 126 | labels). 127 | """ 128 | 129 | gen_kwargs = gen_kwargs.copy() 130 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 131 | gen_kwargs["max_length"] = self.args.generation_max_length 132 | gen_kwargs["num_beams"] = ( 133 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 134 | ) 135 | self._gen_kwargs = gen_kwargs 136 | 137 | 138 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 139 | 140 | def prediction_step( 141 | self, 142 | model: nn.Module, 143 | inputs: Dict[str, Union[torch.Tensor, Any]], 144 | prediction_loss_only: bool, 145 | ignore_keys: Optional[List[str]] = None, 146 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 147 | """ 148 | Perform an evaluation step on `model` using `inputs`. 149 | 150 | Subclass and override to inject custom behavior. 151 | 152 | Args: 153 | model (`nn.Module`): 154 | The model to evaluate. 155 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 156 | The inputs and targets of the model. 157 | 158 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 159 | argument `labels`. Check your model's documentation for all accepted arguments. 160 | prediction_loss_only (`bool`): 161 | Whether or not to return the loss only. 162 | 163 | Return: 164 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 165 | labels (each being optional). 
166 | """ 167 | 168 | if not self.args.predict_with_generate or prediction_loss_only: 169 | return super().prediction_step( 170 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 171 | ) 172 | 173 | has_labels = "labels" in inputs 174 | inputs = self._prepare_inputs(inputs) 175 | 176 | # XXX: adapt synced_gpus for fairscale as well 177 | gen_kwargs = self._gen_kwargs.copy() 178 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 179 | gen_kwargs["max_length"] = self.model.config.max_length 180 | gen_kwargs["num_beams"] = ( 181 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 182 | ) 183 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 184 | gen_kwargs["synced_gpus"] = ( 185 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 186 | ) 187 | 188 | if "attention_mask" in inputs: 189 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 190 | if "position_ids" in inputs: 191 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 192 | if "global_attention_mask" in inputs: 193 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 194 | 195 | # prepare generation inputs 196 | # some encoder-decoder models can have varying encoder's and thus 197 | # varying model input names 198 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 199 | generation_inputs = inputs[self.model.encoder.main_input_name] 200 | else: 201 | generation_inputs = inputs[self.model.main_input_name] 202 | 203 | gen_kwargs["input_ids"] = generation_inputs 204 | generated_tokens = self.model.generate(**gen_kwargs) 205 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] 206 | 207 | # in case the batch is shorter than max length, the output should be padded 208 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: 209 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 210 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( 211 | gen_kwargs["max_new_tokens"] + 1 212 | ): 213 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 214 | 215 | loss = None 216 | 217 | if self.args.prediction_loss_only: 218 | return (loss, None, None) 219 | 220 | if has_labels: 221 | labels = inputs["labels"] 222 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: 223 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 224 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( 225 | gen_kwargs["max_new_tokens"] + 1 226 | ): 227 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) 228 | else: 229 | labels = None 230 | 231 | return (loss, generated_tokens, labels) 232 | 233 | def _pad_tensors_to_max_len(self, tensor, max_length): 234 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 235 | # If PAD token is not defined at least EOS token has to be defined 236 | pad_token_id = ( 237 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 238 | ) 239 | else: 240 | if self.model.config.pad_token_id is not None: 241 | pad_token_id = self.model.config.pad_token_id 242 | else: 243 | raise 
ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 244 | 245 | padded_tensor = pad_token_id * torch.ones( 246 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 247 | ) 248 | padded_tensor[:, : tensor.shape[-1]] = tensor 249 | return padded_tensor 250 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | ptuning_checkpoint: str = field( 15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} 16 | ) 17 | config_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 26 | ) 27 | use_fast_tokenizer: bool = field( 28 | default=True, 29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 30 | ) 31 | model_revision: str = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 34 | ) 35 | use_auth_token: bool = field( 36 | default=False, 37 | metadata={ 38 | "help": ( 39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 40 | "with private models)." 41 | ) 42 | }, 43 | ) 44 | resize_position_embeddings: Optional[bool] = field( 45 | default=None, 46 | metadata={ 47 | "help": ( 48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 49 | "the model's position embeddings." 50 | ) 51 | }, 52 | ) 53 | quantization_bit: Optional[int] = field( 54 | default=None 55 | ) 56 | pre_seq_len: Optional[int] = field( 57 | default=None 58 | ) 59 | prefix_projection: bool = field( 60 | default=False 61 | ) 62 | 63 | 64 | @dataclass 65 | class DataTrainingArguments: 66 | """ 67 | Arguments pertaining to what data we are going to input our model for training and eval. 
68 | """ 69 | 70 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 71 | 72 | dataset_name: Optional[str] = field( 73 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 74 | ) 75 | dataset_config_name: Optional[str] = field( 76 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 77 | ) 78 | prompt_column: Optional[str] = field( 79 | default=None, 80 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 81 | ) 82 | response_column: Optional[str] = field( 83 | default=None, 84 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 85 | ) 86 | history_column: Optional[str] = field( 87 | default=None, 88 | metadata={"help": "The name of the column in the datasets containing the history of chat."}, 89 | ) 90 | train_file: Optional[str] = field( 91 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 92 | ) 93 | validation_file: Optional[str] = field( 94 | default=None, 95 | metadata={ 96 | "help": ( 97 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 98 | ) 99 | }, 100 | ) 101 | test_file: Optional[str] = field( 102 | default=None, 103 | metadata={ 104 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 105 | }, 106 | ) 107 | overwrite_cache: bool = field( 108 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 109 | ) 110 | preprocessing_num_workers: Optional[int] = field( 111 | default=None, 112 | metadata={"help": "The number of processes to use for the preprocessing."}, 113 | ) 114 | max_source_length: Optional[int] = field( 115 | default=1024, 116 | metadata={ 117 | "help": ( 118 | "The maximum total input sequence length after tokenization. Sequences longer " 119 | "than this will be truncated, sequences shorter will be padded." 120 | ) 121 | }, 122 | ) 123 | max_target_length: Optional[int] = field( 124 | default=128, 125 | metadata={ 126 | "help": ( 127 | "The maximum total sequence length for target text after tokenization. Sequences longer " 128 | "than this will be truncated, sequences shorter will be padded." 129 | ) 130 | }, 131 | ) 132 | val_max_target_length: Optional[int] = field( 133 | default=None, 134 | metadata={ 135 | "help": ( 136 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 137 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 138 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 139 | "during ``evaluate`` and ``predict``." 140 | ) 141 | }, 142 | ) 143 | pad_to_max_length: bool = field( 144 | default=False, 145 | metadata={ 146 | "help": ( 147 | "Whether to pad all samples to model maximum sentence length. " 148 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 149 | "efficient on GPU but very bad for TPU." 150 | ) 151 | }, 152 | ) 153 | max_train_samples: Optional[int] = field( 154 | default=None, 155 | metadata={ 156 | "help": ( 157 | "For debugging purposes or quicker training, truncate the number of training examples to this " 158 | "value if set." 
159 | ) 160 | }, 161 | ) 162 | max_eval_samples: Optional[int] = field( 163 | default=None, 164 | metadata={ 165 | "help": ( 166 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 167 | "value if set." 168 | ) 169 | }, 170 | ) 171 | max_predict_samples: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": ( 175 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 176 | "value if set." 177 | ) 178 | }, 179 | ) 180 | num_beams: Optional[int] = field( 181 | default=None, 182 | metadata={ 183 | "help": ( 184 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 185 | "which is used during ``evaluate`` and ``predict``." 186 | ) 187 | }, 188 | ) 189 | ignore_pad_token_for_loss: bool = field( 190 | default=True, 191 | metadata={ 192 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 193 | }, 194 | ) 195 | source_prefix: Optional[str] = field( 196 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 197 | ) 198 | 199 | forced_bos_token: Optional[str] = field( 200 | default=None, 201 | metadata={ 202 | "help": ( 203 | "The token to force as the first generated token after the decoder_start_token_id." 204 | "Useful for multilingual models like mBART where the first generated token" 205 | "needs to be the target language token (Usually it is the target language token)" 206 | ) 207 | }, 208 | ) 209 | 210 | 211 | 212 | def __post_init__(self): 213 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None: 214 | raise ValueError("Need either a dataset name or a training/validation/test file.") 215 | else: 216 | if self.train_file is not None: 217 | extension = self.train_file.split(".")[-1] 218 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 219 | if self.validation_file is not None: 220 | extension = self.validation_file.split(".")[-1] 221 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
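        # With --prompt_column input --response_column target (as passed by the PromptCBLUE shell
        # scripts), each line of train.json / dev.json is expected to be a JSON object roughly of the form
        #   {"input": "<task instruction + medical text>", "target": "<expected answer>"}
        # (an illustrative sketch; the files under datasets/PromptCBLUE/toy_examples define the real schema)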
222 | if self.val_max_target_length is None: 223 | self.val_max_target_length = self.max_target_length 224 | 225 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "THUDM/chatglm-6b", 3 | "architectures": [ 4 | "ChatGLMModel" 5 | ], 6 | "auto_map": { 7 | "AutoConfig": "configuration_chatglm.ChatGLMConfig", 8 | "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", 9 | "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration" 10 | }, 11 | "bos_token_id": 130004, 12 | "eos_token_id": 130005, 13 | "mask_token_id": 130000, 14 | "gmask_token_id": 130001, 15 | "pad_token_id": 3, 16 | "hidden_size": 4096, 17 | "inner_hidden_size": 16384, 18 | "layernorm_epsilon": 1e-05, 19 | "max_sequence_length": 2048, 20 | "model_type": "chatglm", 21 | "num_attention_heads": 32, 22 | "num_layers": 28, 23 | "position_encoding_2d": true, 24 | "torch_dtype": "float16", 25 | "transformers_version": "4.23.1", 26 | "use_cache": true, 27 | "vocab_size": 130528 28 | } 29 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/configuration_chatglm.py: -------------------------------------------------------------------------------- 1 | """ ChatGLM model configuration """ 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | class ChatGLMConfig(PretrainedConfig): 10 | r""" 11 | This is the configuration class to store the configuration of a [`~ChatGLMModel`]. 12 | It is used to instantiate an ChatGLM model according to the specified arguments, defining the model 13 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 14 | the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. 15 | 16 | Configuration objects inherit from [`PretrainedConfig`] and can be used 17 | to control the model outputs. Read the documentation from [`PretrainedConfig`] 18 | for more information. 19 | 20 | 21 | Args: 22 | vocab_size (`int`, *optional*, defaults to 150528): 23 | Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the 24 | `inputs_ids` passed when calling [`~ChatGLMModel`] or 25 | [`~TFChatGLMModel`]. 26 | hidden_size (`int`, *optional*, defaults to 4096): 27 | Dimension of the encoder layers and the pooler layer. 28 | num_hidden_layers (`int`, *optional*, defaults to 28): 29 | Number of hidden layers in the Transformer encoder. 30 | num_attention_heads (`int`, *optional*, defaults to 32): 31 | Number of attention heads for each attention layer in the Transformer encoder. 32 | inner_hidden_size (`int`, *optional*, defaults to 16384): 33 | Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 34 | max_sequence_length (`int`, *optional*, defaults to 512): 35 | The maximum sequence length that this model might ever be used with. 36 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 37 | layernorm_epsilon (`float`, *optional*, defaults to 1e-5): 38 | The epsilon used by the layer normalization layers. 39 | use_cache (`bool`, *optional*, defaults to `True`): 40 | Whether the model should return the last key/values attentions (not used by all models). 
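        quantization_bit (`int`, *optional*, defaults to 0):
            Bit width used for weight quantization (e.g. 4 or 8); 0 keeps the weights unquantized.
        pre_seq_len (`int`, *optional*, defaults to `None`):
            Number of prefix tokens prepended for P-tuning v2; `None` disables the prefix encoder.
        prefix_projection (`bool`, *optional*, defaults to `False`):
            Whether to reparameterize the prefix through an MLP projection instead of using the
            prefix embeddings directly (P-tuning v2 style).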
41 | Example: 42 | 43 | ```python 44 | >>> from configuration_chatglm import ChatGLMConfig 45 | >>> from modeling_chatglm import ChatGLMModel 46 | 47 | >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration 48 | >>> configuration = ChatGLMConfig() 49 | 50 | >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration 51 | >>> model = ChatGLMModel(configuration) 52 | 53 | >>> # Accessing the model configuration 54 | >>> configuration = model.config 55 | ``` 56 | """ 57 | model_type = "chatglm" 58 | 59 | def __init__( 60 | self, 61 | vocab_size=150528, 62 | hidden_size=4096, 63 | num_layers=28, 64 | num_attention_heads=32, 65 | layernorm_epsilon=1e-5, 66 | use_cache=False, 67 | bos_token_id=150004, 68 | eos_token_id=150005, 69 | mask_token_id=150000, 70 | gmask_token_id=150001, 71 | pad_token_id=0, 72 | max_sequence_length=2048, 73 | inner_hidden_size=16384, 74 | position_encoding_2d=True, 75 | quantization_bit=0, 76 | pre_seq_len=None, 77 | prefix_projection=False, 78 | **kwargs 79 | ): 80 | self.num_layers = num_layers 81 | self.vocab_size = vocab_size 82 | self.hidden_size = hidden_size 83 | self.num_attention_heads = num_attention_heads 84 | self.max_sequence_length = max_sequence_length 85 | self.layernorm_epsilon = layernorm_epsilon 86 | self.inner_hidden_size = inner_hidden_size 87 | self.use_cache = use_cache 88 | self.bos_token_id = bos_token_id 89 | self.eos_token_id = eos_token_id 90 | self.pad_token_id = pad_token_id 91 | self.mask_token_id = mask_token_id 92 | self.gmask_token_id = gmask_token_id 93 | self.position_encoding_2d = position_encoding_2d 94 | self.quantization_bit = quantization_bit 95 | self.pre_seq_len = pre_seq_len 96 | self.prefix_projection = prefix_projection 97 | 98 | super().__init__( 99 | pad_token_id=pad_token_id, 100 | bos_token_id=bos_token_id, 101 | eos_token_id=eos_token_id, 102 | **kwargs 103 | ) 104 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/evaluate.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | CHECKPOINT="./experiments/outputs/PromptCBLUE-chatglm-6b-pt-128-2e-2" # 填入用来存储模型的文件夹路径 3 | STEP=10 # 用来评估的模型checkpoint是训练了多少步 4 | 5 | your_data_path="./datasets/PromptCBLUE/toy_examples/" # 填入数据集所在的文件夹路径 6 | model_name_or_path="./models--THUDM--chatglm-6b/snapshots/a8ede826cf1b62bd3c78bdfb3625c7c5d2048fbd" # LLM底座模型路径,或者是huggingface hub上的模型名称 7 | 8 | 9 | CUDA_VISIBLE_DEVICES=1 python src/ft_chatglm_ptuning/main.py \ 10 | --do_predict \ 11 | --do_eval \ 12 | --validation_file $your_data_path/dev.json \ 13 | --test_file $your_data_path/test.json \ 14 | --overwrite_cache \ 15 | --prompt_column input \ 16 | --response_column target \ 17 | --model_name_or_path $model_name_or_path \ 18 | --ptuning_checkpoint $CHECKPOINT/checkpoint-$STEP \ 19 | --output_dir $CHECKPOINT \ 20 | --overwrite_output_dir \ 21 | --max_source_length 700 \ 22 | --max_target_length 196 \ 23 | --per_device_eval_batch_size 1 \ 24 | --predict_with_generate \ 25 | --pre_seq_len $PRE_SEQ_LEN 26 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/test_modeling_chatglm.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import unittest 4 | import torch 5 | import random 6 | 7 | from transformers import AutoTokenizer, AutoModel 8 | from transformers.testing_utils import require_torch, slow, torch_device 9 | 10 | 11 | def 
set_random_seed(seed): 12 | import random 13 | 14 | random.seed(seed) 15 | 16 | # pytorch RNGs 17 | import torch 18 | 19 | torch.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | if torch.cuda.is_available(): 22 | torch.cuda.manual_seed_all(seed) 23 | 24 | # numpy RNG 25 | import numpy as np 26 | 27 | np.random.seed(seed) 28 | 29 | 30 | 31 | def ids_tensor(shape, vocab_size): 32 | # Creates a random int32 tensor of the shape within the vocab size 33 | total_dims = 1 34 | for dim in shape: 35 | total_dims *= dim 36 | 37 | values = [] 38 | for _ in range(total_dims): 39 | values.append(random.randint(0, vocab_size - 1)) 40 | 41 | return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() 42 | 43 | 44 | def get_model_and_tokenizer(): 45 | model = AutoModel.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True).half() 46 | model.to(torch_device) 47 | model.eval() 48 | tokenizer = AutoTokenizer.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True) 49 | return model, tokenizer 50 | 51 | 52 | @require_torch 53 | class ChatGLMGenerationTest(unittest.TestCase): 54 | def get_generation_kwargs(self): 55 | pass 56 | 57 | def test_chat(self): 58 | model, tokenizer = get_model_and_tokenizer() 59 | prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] 60 | history = [] 61 | set_random_seed(42) 62 | expected_responses = [ 63 | '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 64 | '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', 65 | '清华大学创建于 1911 年。' 66 | ] 67 | for (prompt, expected_response) in zip(prompts, expected_responses): 68 | response, history = model.chat(tokenizer, prompt, history=history) 69 | print(repr(response)) 70 | self.assertEquals(expected_response, response) 71 | 72 | def test_stream_chat(self): 73 | model, tokenizer = get_model_and_tokenizer() 74 | prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] 75 | history = [] 76 | expected_responses = [ 77 | '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 78 | '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', 79 | '清华大学创建于 1911 年。' 80 | ] 81 | set_random_seed(42) 82 | for prompt, expected_response in zip(prompts, expected_responses): 83 | response = "" 84 | for idx, (response, history) in enumerate(model.stream_chat(tokenizer, prompt, history=history)): 85 | pass 86 | print(repr(response)) 87 | self.assertEquals(expected_response, response) 88 | 89 | def test_generation(self): 90 | model, tokenizer = get_model_and_tokenizer() 91 | sentence = "晚上睡不着怎么办" 92 | parameters = [(False, 2048, 1), 93 | (False, 64, 1), 94 | (True, 2048, 1), 95 | (True, 64, 1), 96 | (True, 2048, 4)] 97 | expected_out_sentences = [ 98 | '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 
寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', 99 | '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', 100 | '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', 101 | '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', 102 | '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] 103 | for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): 104 | set_random_seed(42) 105 | inputs = tokenizer(sentence, return_tensors="pt") 106 | inputs = inputs.to(torch_device) 107 | 108 | outputs = model.generate( 109 | **inputs, 110 | do_sample=do_sample, 111 | max_length=max_length, 112 | num_beams=num_beams 113 | ) 114 | 115 | outputs = outputs.tolist()[0] 116 | out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) 117 | print(out_sentence) 118 | self.assertEquals(expected_output_sentence, out_sentence) 119 | 120 | def test_batch_generation(self): 121 | model, tokenizer = get_model_and_tokenizer() 122 | sentences = [ 123 | "你好", 124 | "介绍一下清华大学" 125 | ] 126 | parameters = [(False, 2048, 1), 127 | (False, 64, 1), 128 | (True, 2048, 1), 129 | (True, 64, 1), 130 | (True, 2048, 4)] 131 | expected_out_sentences = [ 132 | ['你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 133 | '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回清华园。新中国成立后,清华学校更名为清华大学。\n\n清华大学是中国最顶尖的大学之一,在工程、科学、技术、经济、管理等领域都有很高的学术声誉和影响力。学校拥有世界一流的教学设施和科学研究平台,有多个学院和研究中心,包括工程学院、自然科学学院、人文学院、社会科学学院、经济管理学院、法学院、美术学院、医学院、器学院等。\n\n清华大学的本科生招生始于2000年,实行全面二孩政策后,本科生招生规模不断扩大。截至2022年,清华大学共有本科生近3万人,研究生近2万人,其中国际学生占比约为10%。清华大学的本科生教育注重通识教育和个性化培养,强调实践、创新、国际化和综合素质。'], 134 | [ 135 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 136 | '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回' 137 | ], 138 | [ 139 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 140 | '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后闭校。1949 年 10 月开学复校,成为我国第一个社会主义大学生活了的高校。截至 2023 年,清华学校共管辖 2 个学院、13 个系,有本科专业 60 个,研究生专业 190 个。' 141 | ], 142 | [ 143 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 144 | '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后' 145 | ], 146 | [ 147 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 148 | '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,与北京大学、南开大学组建国立长沙临时大学,1938年迁至 昆明改名为国立西南联合大学,1946年迁回北京。新中国成立后,清华学校更名为清华大学。' 149 | ] 150 | ] 151 | for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): 152 | 
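            # Each (do_sample, max_length, num_beams) tuple exercises a different decoding mode:
            # greedy (False, *, 1), sampling (True, *, 1) and beam-search sampling (True, *, 4),
            # at both the full and the truncated max_length; set_random_seed(42) keeps the sampled
            # outputs reproducible so they can be checked against the expected strings above.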
set_random_seed(42) 153 | inputs = tokenizer(sentences, return_tensors="pt", padding=True) 154 | inputs = inputs.to(torch_device) 155 | 156 | outputs = model.generate( 157 | **inputs, 158 | do_sample=do_sample, 159 | max_length=max_length, 160 | num_beams=num_beams 161 | ) 162 | 163 | batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) 164 | print(batch_out_sentence) 165 | self.assertListEqual(expected_output_sentence, batch_out_sentence) 166 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name_or_path": "THUDM/chatglm-6b", 3 | "bos_token": "", 4 | "eos_token": "", 5 | "end_token": "", 6 | "gmask_token": "[gMASK]", 7 | "mask_token": "[MASK]", 8 | "pad_token": "", 9 | "unk_token": "", 10 | "remove_space": false, 11 | "do_lower_case": false, 12 | "tokenizer_class": "ChatGLMTokenizer", 13 | "num_image_tokens": 0, 14 | "auto_map": { 15 | "AutoTokenizer": [ 16 | "tokenization_chatglm.ChatGLMTokenizer", 17 | null 18 | ] 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/train.sh: -------------------------------------------------------------------------------- 1 | PRE_SEQ_LEN=128 2 | LR=2e-2 3 | your_data_path="./datasets/PromptCBLUE/toy_examples/" # 填入数据集所在的文件夹路径 4 | your_checkpopint_path="./experiments/outputs/" # 填入用来存储模型的路径 5 | model_name_or_path="./models--THUDM--chatglm-6b/snapshots/a8ede826cf1b62bd3c78bdfb3625c7c5d2048fbd" # LLM底座模型路径,或者是huggingface hub上的模型名称 6 | 7 | 8 | CUDA_VISIBLE_DEVICES=1 python src/ft_chatglm_ptuning/main.py \ 9 | --do_train \ 10 | --train_file $your_data_path/train.json \ 11 | --validation_file $your_data_path/dev.json \ 12 | --prompt_column input \ 13 | --response_column target \ 14 | --overwrite_cache \ 15 | --model_name_or_path $model_name_or_path \ 16 | --output_dir $your_checkpopint_path/PromptCBLUE-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ 17 | --overwrite_output_dir \ 18 | --max_source_length 700 \ 19 | --max_target_length 196 \ 20 | --per_device_train_batch_size 8 \ 21 | --per_device_eval_batch_size 8 \ 22 | --gradient_accumulation_steps 2 \ 23 | --max_steps 10000 \ 24 | --logging_steps 10 \ 25 | --save_steps 10 \ 26 | --learning_rate $LR \ 27 | --pre_seq_len $PRE_SEQ_LEN 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /src/ft_chatglm_ptuning/trainer_seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.utils.data import Dataset 20 | 21 | from transformers.deepspeed import is_deepspeed_zero3_enabled 22 | 23 | from transformers.trainer_utils import PredictionOutput 24 | from transformers.utils import logging 25 | 26 | 27 | from .trainer import Trainer 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | class Seq2SeqTrainer(Trainer): 33 | def evaluate( 34 | self, 35 | eval_dataset: Optional[Dataset] = None, 36 | ignore_keys: Optional[List[str]] = None, 37 | metric_key_prefix: str = "eval", 38 | **gen_kwargs 39 | ) -> Dict[str, float]: 40 | """ 41 | Run evaluation and returns metrics. 42 | 43 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 44 | (pass it to the init `compute_metrics` argument). 45 | 46 | You can also subclass and override this method to inject custom behavior. 47 | 48 | Args: 49 | eval_dataset (`Dataset`, *optional*): 50 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns 51 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 52 | method. 53 | ignore_keys (`List[str]`, *optional*): 54 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 55 | gathering predictions. 56 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 57 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 58 | "eval_bleu" if the prefix is `"eval"` (default) 59 | max_length (`int`, *optional*): 60 | The maximum target length to use when predicting with the generate method. 61 | num_beams (`int`, *optional*): 62 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 63 | beam search. 64 | gen_kwargs: 65 | Additional `generate` specific kwargs. 66 | 67 | Returns: 68 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 69 | dictionary also contains the epoch number which comes from the training state. 70 | """ 71 | 72 | gen_kwargs = gen_kwargs.copy() 73 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 74 | gen_kwargs["max_length"] = self.args.generation_max_length 75 | gen_kwargs["num_beams"] = ( 76 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 77 | ) 78 | self._gen_kwargs = gen_kwargs 79 | 80 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 81 | 82 | def predict( 83 | self, 84 | test_dataset: Dataset, 85 | ignore_keys: Optional[List[str]] = None, 86 | metric_key_prefix: str = "test", 87 | **gen_kwargs 88 | ) -> PredictionOutput: 89 | """ 90 | Run prediction and returns predictions and potential metrics. 91 | 92 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 93 | will also return metrics, like in `evaluate()`. 94 | 95 | Args: 96 | test_dataset (`Dataset`): 97 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the 98 | `model.forward()` method are automatically removed. 
Has to implement the method `__len__` 99 | ignore_keys (`List[str]`, *optional*): 100 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 101 | gathering predictions. 102 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 103 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 104 | "eval_bleu" if the prefix is `"eval"` (default) 105 | max_length (`int`, *optional*): 106 | The maximum target length to use when predicting with the generate method. 107 | num_beams (`int`, *optional*): 108 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 109 | beam search. 110 | gen_kwargs: 111 | Additional `generate` specific kwargs. 112 | 113 | 114 | 115 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic 116 | padding in a token classification task) the predictions will be padded (on the right) to allow for 117 | concatenation into one array. The padding index is -100. 118 | 119 | 120 | 121 | Returns: *NamedTuple* A namedtuple with the following keys: 122 | 123 | - predictions (`np.ndarray`): The predictions on `test_dataset`. 124 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). 125 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained 126 | labels). 127 | """ 128 | 129 | gen_kwargs = gen_kwargs.copy() 130 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 131 | gen_kwargs["max_length"] = self.args.generation_max_length 132 | gen_kwargs["num_beams"] = ( 133 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 134 | ) 135 | self._gen_kwargs = gen_kwargs 136 | 137 | 138 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 139 | 140 | def prediction_step( 141 | self, 142 | model: nn.Module, 143 | inputs: Dict[str, Union[torch.Tensor, Any]], 144 | prediction_loss_only: bool, 145 | ignore_keys: Optional[List[str]] = None, 146 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 147 | """ 148 | Perform an evaluation step on `model` using `inputs`. 149 | 150 | Subclass and override to inject custom behavior. 151 | 152 | Args: 153 | model (`nn.Module`): 154 | The model to evaluate. 155 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 156 | The inputs and targets of the model. 157 | 158 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 159 | argument `labels`. Check your model's documentation for all accepted arguments. 160 | prediction_loss_only (`bool`): 161 | Whether or not to return the loss only. 162 | 163 | Return: 164 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 165 | labels (each being optional). 
166 | """ 167 | 168 | if not self.args.predict_with_generate or prediction_loss_only: 169 | return super().prediction_step( 170 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 171 | ) 172 | 173 | has_labels = "labels" in inputs 174 | inputs = self._prepare_inputs(inputs) 175 | 176 | # XXX: adapt synced_gpus for fairscale as well 177 | gen_kwargs = self._gen_kwargs.copy() 178 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 179 | gen_kwargs["max_length"] = self.model.config.max_length 180 | gen_kwargs["num_beams"] = ( 181 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 182 | ) 183 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 184 | gen_kwargs["synced_gpus"] = ( 185 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 186 | ) 187 | 188 | if "attention_mask" in inputs: 189 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 190 | if "position_ids" in inputs: 191 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 192 | if "global_attention_mask" in inputs: 193 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 194 | 195 | # prepare generation inputs 196 | # some encoder-decoder models can have varying encoder's and thus 197 | # varying model input names 198 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 199 | generation_inputs = inputs[self.model.encoder.main_input_name] 200 | else: 201 | generation_inputs = inputs[self.model.main_input_name] 202 | 203 | gen_kwargs["input_ids"] = generation_inputs 204 | generated_tokens = self.model.generate(**gen_kwargs) 205 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] 206 | 207 | # in case the batch is shorter than max length, the output should be padded 208 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: 209 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 210 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( 211 | gen_kwargs["max_new_tokens"] + 1 212 | ): 213 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) 214 | 215 | loss = None 216 | 217 | if self.args.prediction_loss_only: 218 | return (loss, None, None) 219 | 220 | if has_labels: 221 | labels = inputs["labels"] 222 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: 223 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 224 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( 225 | gen_kwargs["max_new_tokens"] + 1 226 | ): 227 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) 228 | else: 229 | labels = None 230 | 231 | return (loss, generated_tokens, labels) 232 | 233 | def _pad_tensors_to_max_len(self, tensor, max_length): 234 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 235 | # If PAD token is not defined at least EOS token has to be defined 236 | pad_token_id = ( 237 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 238 | ) 239 | else: 240 | if self.model.config.pad_token_id is not None: 241 | pad_token_id = self.model.config.pad_token_id 242 | else: 243 | raise 
ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 244 | 245 | padded_tensor = pad_token_id * torch.ones( 246 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 247 | ) 248 | padded_tensor[:, : tensor.shape[-1]] = tensor 249 | return padded_tensor 250 | -------------------------------------------------------------------------------- /src/ft_llama_lora/merge_llama_with_chinese_lora.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python medical_prompts/src/ft_llama_lora/merge_llama_with_chinese_lora.py \ 4 | --base_model ./models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 \ 5 | --lora_model ./models--ziqingyang--chinese-llama-plus-lora-7b/snapshots/32115d9a87767a8e00464dc560030a12bf38cb24,./models--ziqingyang--chinese-alpaca-plus-lora-7b/snapshots/8f4c20016de3c4c9a6fb47bc7082583849a37285 \ 6 | --output_type huggingface \ 7 | --output_dir ./resources/chinese-llama-alpaca-plus-lora-7b 8 | """ 9 | import argparse 10 | import json 11 | import os 12 | import gc 13 | import torch 14 | 15 | import sys 16 | sys.path.append("./") 17 | 18 | import peft 19 | from peft import PeftModel 20 | from transformers import LlamaForCausalLM, LlamaTokenizer 21 | from huggingface_hub import hf_hub_download 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--base_model', default=None, required=True, 25 | type=str, help="Please specify a base_model") 26 | parser.add_argument('--lora_model', default=None, required=True, 27 | type=str, help="Please specify LoRA models to be merged (ordered); use commas to separate multiple LoRA models.") 28 | parser.add_argument('--offload_dir', default=None, type=str, 29 | help="(Optional) Please specify a temp folder for offloading (useful for low-RAM machines). 
Default None (disable offload).") 30 | parser.add_argument('--output_type', default='pth',choices=['pth','huggingface'], type=str, 31 | help="save the merged model in pth or huggingface format.") 32 | parser.add_argument('--output_dir', default='./', type=str) 33 | 34 | 35 | emb_to_model_size = { 36 | 4096 : '7B', 37 | 5120 : '13B', 38 | 6656 : '30B', 39 | 8192 : '65B', 40 | } 41 | num_shards_of_models = {'7B': 1, '13B': 2} 42 | params_of_models = { 43 | '7B': 44 | { 45 | "dim": 4096, 46 | "multiple_of": 256, 47 | "n_heads": 32, 48 | "n_layers": 32, 49 | "norm_eps": 1e-06, 50 | "vocab_size": -1, 51 | }, 52 | '13B': 53 | { 54 | "dim": 5120, 55 | "multiple_of": 256, 56 | "n_heads": 40, 57 | "n_layers": 40, 58 | "norm_eps": 1e-06, 59 | "vocab_size": -1, 60 | }, 61 | } 62 | 63 | def transpose(weight, fan_in_fan_out): 64 | return weight.T if fan_in_fan_out else weight 65 | 66 | # Borrowed and modified from https://github.com/tloen/alpaca-lora 67 | def translate_state_dict_key(k): 68 | k = k.replace("base_model.model.", "") 69 | if k == "model.embed_tokens.weight": 70 | return "tok_embeddings.weight" 71 | elif k == "model.norm.weight": 72 | return "norm.weight" 73 | elif k == "lm_head.weight": 74 | return "output.weight" 75 | elif k.startswith("model.layers."): 76 | layer = k.split(".")[2] 77 | if k.endswith(".self_attn.q_proj.weight"): 78 | return f"layers.{layer}.attention.wq.weight" 79 | elif k.endswith(".self_attn.k_proj.weight"): 80 | return f"layers.{layer}.attention.wk.weight" 81 | elif k.endswith(".self_attn.v_proj.weight"): 82 | return f"layers.{layer}.attention.wv.weight" 83 | elif k.endswith(".self_attn.o_proj.weight"): 84 | return f"layers.{layer}.attention.wo.weight" 85 | elif k.endswith(".mlp.gate_proj.weight"): 86 | return f"layers.{layer}.feed_forward.w1.weight" 87 | elif k.endswith(".mlp.down_proj.weight"): 88 | return f"layers.{layer}.feed_forward.w2.weight" 89 | elif k.endswith(".mlp.up_proj.weight"): 90 | return f"layers.{layer}.feed_forward.w3.weight" 91 | elif k.endswith(".input_layernorm.weight"): 92 | return f"layers.{layer}.attention_norm.weight" 93 | elif k.endswith(".post_attention_layernorm.weight"): 94 | return f"layers.{layer}.ffn_norm.weight" 95 | elif k.endswith("rotary_emb.inv_freq") or "lora" in k: 96 | return None 97 | else: 98 | print(layer, k) 99 | raise NotImplementedError 100 | else: 101 | print(k) 102 | raise NotImplementedError 103 | 104 | 105 | def unpermute(w): 106 | return ( 107 | w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim) 108 | ) 109 | 110 | 111 | def save_shards(model_sd, num_shards: int): 112 | # Add the no_grad context manager 113 | with torch.no_grad(): 114 | if num_shards == 1: 115 | new_state_dict = {} 116 | for k, v in model_sd.items(): 117 | new_k = translate_state_dict_key(k) 118 | if new_k is not None: 119 | if "wq" in new_k or "wk" in new_k: 120 | new_state_dict[new_k] = unpermute(v) 121 | else: 122 | new_state_dict[new_k] = v 123 | 124 | os.makedirs(output_dir, exist_ok=True) 125 | print(f"Saving shard 1 of {num_shards} into {output_dir}/consolidated.00.pth") 126 | torch.save(new_state_dict, output_dir + "/consolidated.00.pth") 127 | with open(output_dir + "/params.json", "w") as f: 128 | json.dump(params, f) 129 | else: 130 | new_state_dicts = [dict() for _ in range(num_shards)] 131 | for k in list(model_sd.keys()): 132 | v = model_sd[k] 133 | new_k = translate_state_dict_key(k) 134 | if new_k is not None: 135 | if new_k=='tok_embeddings.weight': 136 | print(f"Processing {new_k}") 137 | assert 
v.size(1)%num_shards==0 138 | splits = v.split(v.size(1)//num_shards,dim=1) 139 | elif new_k=='output.weight': 140 | print(f"Processing {new_k}") 141 | splits = v.split(v.size(0)//num_shards,dim=0) 142 | 143 | elif new_k=='norm.weight': 144 | print(f"Processing {new_k}") 145 | splits = [v] * num_shards 146 | elif 'ffn_norm.weight' in new_k: 147 | print(f"Processing {new_k}") 148 | splits = [v] * num_shards 149 | elif 'attention_norm.weight' in new_k: 150 | print(f"Processing {new_k}") 151 | splits = [v] * num_shards 152 | 153 | 154 | elif 'w1.weight' in new_k: 155 | print(f"Processing {new_k}") 156 | splits = v.split(v.size(0)//num_shards,dim=0) 157 | elif 'w2.weight' in new_k: 158 | print(f"Processing {new_k}") 159 | splits = v.split(v.size(1)//num_shards,dim=1) 160 | elif 'w3.weight' in new_k: 161 | print(f"Processing {new_k}") 162 | splits = v.split(v.size(0)//num_shards,dim=0) 163 | 164 | 165 | elif 'wo.weight' in new_k: 166 | print(f"Processing {new_k}") 167 | splits = v.split(v.size(1)//num_shards,dim=1) 168 | 169 | elif 'wv.weight' in new_k: 170 | print(f"Processing {new_k}") 171 | splits = v.split(v.size(0)//num_shards,dim=0) 172 | 173 | elif "wq.weight" in new_k or "wk.weight" in new_k: 174 | print(f"Processing {new_k}") 175 | v = unpermute(v) 176 | splits = v.split(v.size(0)//num_shards,dim=0) 177 | else: 178 | print(f"Unexpected key {new_k}") 179 | raise ValueError 180 | for sd,split in zip(new_state_dicts,splits): 181 | sd[new_k] = split.clone() 182 | del split 183 | del splits 184 | del model_sd[k],v 185 | gc.collect() # Effectively enforce garbage collection 186 | 187 | os.makedirs(output_dir, exist_ok=True) 188 | for i,new_state_dict in enumerate(new_state_dicts): 189 | print(f"Saving shard {i+1} of {num_shards} into {output_dir}/consolidated.0{i}.pth") 190 | torch.save(new_state_dict, output_dir + f"/consolidated.0{i}.pth") 191 | with open(output_dir + "/params.json", "w") as f: 192 | print(f"Saving params.json into {output_dir}/params.json") 193 | json.dump(params, f) 194 | 195 | 196 | if __name__=='__main__': 197 | 198 | args = parser.parse_args() 199 | base_model_path = args.base_model 200 | lora_model_paths = [s.strip() for s in args.lora_model.split(',') if len(s.strip())!=0] 201 | output_dir = args.output_dir 202 | output_type = args.output_type 203 | offload_dir = args.offload_dir 204 | 205 | print(f"Base model: {base_model_path}") 206 | print(f"LoRA model(s) {lora_model_paths}:") 207 | 208 | if offload_dir is not None: 209 | # Load with offloading, which is useful for low-RAM machines. 210 | # Note that if you have enough RAM, please use original method instead, as it is faster. 
211 | base_model = LlamaForCausalLM.from_pretrained( 212 | base_model_path, 213 | load_in_8bit=False, 214 | torch_dtype=torch.float16, 215 | offload_folder=offload_dir, 216 | offload_state_dict=True, 217 | low_cpu_mem_usage=True, 218 | device_map={"": "cpu"}, 219 | ) 220 | else: 221 | # Original method without offloading 222 | base_model = LlamaForCausalLM.from_pretrained( 223 | base_model_path, 224 | load_in_8bit=False, 225 | torch_dtype=torch.float16, 226 | device_map={"": "cpu"}, 227 | ) 228 | print(base_model) 229 | 230 | ## infer the model size from the checkpoint 231 | embedding_size = base_model.get_input_embeddings().weight.size(1) 232 | model_size = emb_to_model_size[embedding_size] 233 | print(f"Peft version: {peft.__version__}") 234 | print(f"Loading LoRA for {model_size} model") 235 | 236 | lora_model = None 237 | lora_model_sd = None 238 | for lora_index, lora_model_path in enumerate(lora_model_paths): 239 | print(f"Loading LoRA {lora_model_path}") 240 | tokenizer = LlamaTokenizer.from_pretrained(lora_model_path) 241 | if base_model.get_input_embeddings().weight.size(0) != len(tokenizer): 242 | base_model.resize_token_embeddings(len(tokenizer)) 243 | print(f"Extended vocabulary size to {len(tokenizer)}") 244 | 245 | first_weight = base_model.model.layers[0].self_attn.q_proj.weight 246 | first_weight_old = first_weight.clone() 247 | 248 | if hasattr(peft.LoraModel,'merge_and_unload'): 249 | lora_model = PeftModel.from_pretrained( 250 | base_model, 251 | lora_model_path, 252 | device_map={"": "cpu"}, 253 | torch_dtype=torch.float16, 254 | ) 255 | assert torch.allclose(first_weight_old, first_weight) 256 | print(f"Merging with merge_and_unload...") 257 | base_model = lora_model.merge_and_unload() 258 | else: 259 | base_model_sd = base_model.state_dict() 260 | try: 261 | lora_model_sd = torch.load(os.path.join(lora_model_path,'adapter_model.bin'),map_location='cpu') 262 | except FileNotFoundError: 263 | print("Cannot find lora model on the disk. Downloading lora model from hub...") 264 | filename = hf_hub_download(repo_id=lora_model_path,filename='adapter_model.bin') 265 | lora_model_sd = torch.load(filename,map_location='cpu') 266 | 267 | lora_config = peft.LoraConfig.from_pretrained(lora_model_path) 268 | lora_scaling = lora_config.lora_alpha / lora_config.r 269 | fan_in_fan_out = lora_config.fan_in_fan_out 270 | lora_keys = [k for k in lora_model_sd if 'lora_A' in k] 271 | non_lora_keys = [k for k in lora_model_sd if not 'lora_' in k] 272 | 273 | for k in non_lora_keys: 274 | print(f"merging {k}") 275 | original_k = k.replace('base_model.model.','') 276 | base_model_sd[original_k].copy_(lora_model_sd[k]) 277 | 278 | for k in lora_keys: 279 | print(f"merging {k}") 280 | original_key = k.replace('.lora_A','').replace('base_model.model.','') 281 | assert original_key in base_model_sd 282 | lora_a_key = k 283 | lora_b_key = k.replace('lora_A','lora_B') 284 | base_model_sd[original_key] += ( 285 | transpose(lora_model_sd[lora_b_key].float() @ lora_model_sd[lora_a_key].float(),fan_in_fan_out) * lora_scaling 286 | ) 287 | assert base_model_sd[original_key].dtype == torch.float16 288 | 289 | # did we do anything? 
290 | assert not torch.allclose(first_weight_old, first_weight) 291 | 292 | tokenizer.save_pretrained(output_dir) 293 | 294 | if output_type=='huggingface': 295 | print("Saving to Hugging Face format...") 296 | LlamaForCausalLM.save_pretrained( 297 | base_model, output_dir, 298 | max_shard_size="2GB" 299 | ) #, state_dict=deloreanized_sd) 300 | else: # output_type=='pth 301 | print("Saving to pth format...") 302 | 303 | base_model_sd = base_model.state_dict() 304 | del lora_model, base_model, lora_model_sd 305 | 306 | params = params_of_models[model_size] 307 | num_shards = num_shards_of_models[model_size] 308 | n_layers = params["n_layers"] 309 | n_heads = params["n_heads"] 310 | dim = params["dim"] 311 | dims_per_head = dim // n_heads 312 | base = 10000.0 313 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 314 | 315 | save_shards(model_sd=base_model_sd, num_shards=num_shards) 316 | -------------------------------------------------------------------------------- /src/ft_llama_lora/run_train.sh: -------------------------------------------------------------------------------- 1 | lr=1e-4 2 | lora_rank=8 3 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" 4 | #lora_trainable="q_proj,v_proj,k_proj,o_proj" 5 | modules_to_save="embed_tokens,lm_head" 6 | lora_dropout=0.1 7 | pretrained_model="./resources/chinese-llama-plus-lora-7b" 8 | dataset_name="./data/promptcblue/test_a_open_0" 9 | dataset_cache_dir="./data/promptcblue/test_a_open_0" 10 | per_device_batch_size=2 11 | per_device_batch_size=2 12 | gradient_accumulation_steps=16 13 | training_steps=10000 14 | output_dir="./experiments/output/promptcblue-llama-7b-pt-v0" 15 | # deepspeed_config_file="src/chatmed_llama_peft/deepspeed_config_zero3_offload.json" 16 | 17 | torchrun \ 18 | --nnodes 1 \ 19 | --nproc_per_node 2 \ 20 | --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:12356 \ 21 | src/ft_llama_lora/run_clm_pt_with_peft.py \ 22 | --model_name_or_path ${pretrained_model} \ 23 | --tokenizer_name_or_path ${pretrained_model} \ 24 | --dataset_name ${dataset_name} \ 25 | --dataset_cache_dir ${dataset_cache_dir} \ 26 | --validation_split_percentage 0.001 \ 27 | --per_device_train_batch_size ${per_device_batch_size} \ 28 | --per_device_eval_batch_size ${per_device_batch_size} \ 29 | --do_train \ 30 | --seed 100 \ 31 | --fp16 \ 32 | --max_steps ${training_steps} \ 33 | --lr_scheduler_type cosine \ 34 | --learning_rate ${lr} \ 35 | --warmup_ratio 0.05 \ 36 | --weight_decay 0.01 \ 37 | --logging_strategy steps \ 38 | --logging_steps 10 \ 39 | --save_strategy steps \ 40 | --save_total_limit 25 \ 41 | --save_steps 100 \ 42 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 43 | --preprocessing_num_workers 8 \ 44 | --block_size 512 \ 45 | --output_dir ${output_dir} \ 46 | --overwrite_output_dir \ 47 | --ddp_timeout 30000 \ 48 | --logging_first_step True \ 49 | --lora_rank ${lora_rank} \ 50 | --trainable ${lora_trainable} \ 51 | --modules_to_save ${modules_to_save} \ 52 | --lora_dropout ${lora_dropout} \ 53 | --torch_dtype float16 54 | 55 | # --deepspeed ${deepspeed_config_file} \ 56 | 57 | # -------------------------------------------------------------------------------- /src/ft_llama_lora/vllm_serving/launch_vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fastapi.responses import StreamingResponse 4 | from fastapi import FastAPI, Request 5 | from transformers import AutoTokenizer, AutoConfig,
AutoModelForCausalLM, pipeline 6 | from queue import Queue 7 | from typing import List, Dict, Optional 8 | import time 9 | import asyncio 10 | import argparse 11 | import threading 12 | from dataclasses import dataclass 13 | from vllm import EngineArgs, SamplingParams 14 | 15 | from utils import Prompter 16 | from llm_engine import LLMEngine 17 | 18 | app = FastAPI() 19 | 20 | 21 | @dataclass 22 | class GenerationInputs: 23 | req_id: int 24 | prompt: str 25 | sampling_config: dict 26 | 27 | 28 | @dataclass 29 | class GenerationOutput: 30 | req_id: int 31 | generated_text: str 32 | num_output_tokens: int 33 | error: str 34 | 35 | 36 | class ModelThread: 37 | def __init__(self, vllm_args, model_ready_event, progress_call, loop): 38 | self.vllm_args = vllm_args 39 | self.model_ready_event = model_ready_event 40 | self.thread = None 41 | self.input_queue = Queue() 42 | self.output_queue = Queue() 43 | 44 | self.progress_call = progress_call 45 | self.loop = loop 46 | 47 | def start_thread(self): 48 | self.thread = threading.Thread(target=self._thread, daemon=True) 49 | self.thread.start() 50 | 51 | def _thread(self): 52 | server = self.init_model(self.vllm_args) 53 | 54 | self.model_ready_event.set() 55 | 56 | while True: 57 | time.sleep(0.01) 58 | 59 | while not self.input_queue.empty(): 60 | gen_input = self.input_queue.get_nowait() 61 | 62 | prompt = gen_input.prompt 63 | sampling_params = SamplingParams( 64 | n=1, 65 | best_of=gen_input.sampling_config.get('best_of', 1), 66 | use_beam_search=gen_input.sampling_config.get('use_beam_search', False), 67 | top_p=gen_input.sampling_config.get('top_p', 1.0), 68 | top_k=gen_input.sampling_config.get('top_k', -1), 69 | max_tokens=gen_input.sampling_config.get('max_tokens', 512), 70 | presence_penalty=gen_input.sampling_config.get('presence_penalty', 0.2), 71 | frequency_penalty=gen_input.sampling_config.get('frequency_penalty', 0.2), 72 | temperature=gen_input.sampling_config.get('temperature', 1e-6), 73 | ) 74 | 75 | req_id = gen_input.req_id 76 | 77 | server.add_request( 78 | str(req_id), 79 | prompt, 80 | sampling_params, 81 | ) 82 | 83 | vllm_outputs = server.step() 84 | 85 | needs_call_progress = False 86 | for vllm_output in vllm_outputs: 87 | if not vllm_output.finished: 88 | continue 89 | 90 | needs_call_progress = True 91 | assert len(vllm_output.outputs) == 1 92 | req_id = int(vllm_output.request_id) 93 | generated_text = vllm_output.outputs[0].text 94 | num_output_tokens = len(vllm_output.outputs[0].token_ids) 95 | 96 | gen_output = GenerationOutput( 97 | req_id=req_id, 98 | generated_text=generated_text, 99 | num_output_tokens=num_output_tokens, 100 | error=None, 101 | ) 102 | self.output_queue.put_nowait(gen_output) 103 | 104 | if needs_call_progress: 105 | asyncio.run_coroutine_threadsafe(self.progress_call(), loop) 106 | 107 | @staticmethod 108 | def init_model(vllm_args): 109 | print('Init model') 110 | server_args = EngineArgs.from_cli_args(vllm_args) 111 | server = LLMEngine.from_engine_args(server_args) 112 | print('Model ready') 113 | return server 114 | 115 | 116 | class FastAPIServer: 117 | def __init__(self, loop, vllm_args): 118 | self.model_ready_event = asyncio.Event() 119 | 120 | self.requests = {} 121 | self.generations = {} 122 | self.request_queue = [] 123 | self._next_req_id = 0 124 | 125 | self.loop = loop 126 | 127 | self.model_thread = ModelThread( 128 | vllm_args, self.model_ready_event, self.progress_async, self.loop) 129 | self.model_thread.start_thread() 130 | 131 | @property 132 | def next_req_id(self): 133 | 
rval = self._next_req_id 134 | self._next_req_id += 1 135 | return rval 136 | 137 | async def progress_async(self): 138 | return self.progress() 139 | 140 | def progress(self): 141 | sent_to_model = 0 142 | recv_from_model = 0 143 | 144 | for req_id in self.request_queue: 145 | prompt, sampling_config = self.requests[req_id] 146 | gen_inputs = GenerationInputs( 147 | req_id, 148 | prompt, 149 | sampling_config, 150 | ) 151 | self.model_thread.input_queue.put_nowait(gen_inputs) 152 | sent_to_model += 1 153 | self.request_queue = [] 154 | 155 | found_outputs = [] 156 | while not self.model_thread.output_queue.empty(): 157 | gen_output = self.model_thread.output_queue.get_nowait() 158 | found_outputs.append(gen_output) 159 | recv_from_model += 1 160 | 161 | for output in found_outputs: 162 | req_id = output.req_id 163 | ready_event, _, _, _ = self.generations[req_id] 164 | self.generations[req_id] = ( 165 | ready_event, output.generated_text, output.num_output_tokens, output.error) 166 | ready_event.set() 167 | 168 | print(f'progress {sent_to_model=} {recv_from_model=}') 169 | 170 | async def is_ready(self): 171 | return self.model_ready_event.is_set() 172 | 173 | def add_request(self, prompt, sampling_config): 174 | req_id = self.next_req_id 175 | self.requests[req_id] = (prompt, sampling_config) 176 | self.request_queue.append(req_id) 177 | 178 | ready_event = asyncio.Event() 179 | self.generations[req_id] = (ready_event, None, None, None) 180 | return req_id 181 | 182 | async def get_generation(self, req_id): 183 | ready_event, _, _, _ = self.generations[req_id] 184 | await ready_event.wait() 185 | _, generation, num_output_tokens, error = self.generations[req_id] 186 | 187 | del self.generations[req_id] 188 | del self.requests[req_id] 189 | return generation, num_output_tokens, error 190 | 191 | async def generate(self, request_dict: Dict): 192 | global prompter 193 | 194 | instruction = request_dict.get('instruction') 195 | input = request_dict['input'] 196 | prompt = prompter.generate_prompt(instruction=instruction, input=input) 197 | sampling_config = request_dict['parameters'] 198 | 199 | req_id = self.add_request(prompt, sampling_config) 200 | self.progress() 201 | generation, num_output_tokens, error = await self.get_generation(req_id) 202 | 203 | return { 204 | 'generated_text': generation, 205 | 'num_output_tokens_cf': num_output_tokens, 206 | 'error': error, 207 | } 208 | 209 | 210 | @app.post("/generate") 211 | async def generate_stream(request: Request): 212 | request_dict = await request.json() 213 | return await server.generate(request_dict) 214 | 215 | 216 | @app.get("/is_ready") 217 | async def is_ready(request: Request): 218 | return await server.is_ready() 219 | 220 | 221 | if __name__ == "__main__": 222 | parser = argparse.ArgumentParser() 223 | parser.add_argument('--port', type=int, required=True) 224 | parser.add_argument('--template_path', type=str, default='data/templates/alpaca.json') 225 | EngineArgs.add_cli_args(parser) 226 | args = parser.parse_args() 227 | 228 | vllm_args = EngineArgs.from_cli_args(args) 229 | 230 | prompter = Prompter(args.template_path) 231 | 232 | loop = asyncio.new_event_loop() 233 | server = FastAPIServer(loop, vllm_args) 234 | 235 | from uvicorn import Config, Server 236 | 237 | config = Config(app=app, loop=loop, host='0.0.0.0', 238 | port=args.port, log_level="info") 239 | uvicorn_server = Server(config) 240 | 241 | loop.run_until_complete(uvicorn_server.serve()) 242 | 243 | 244 | ''' 245 | max_num_batched_tokens=8000 246 | 247 | 
CUDA_VISIBLE_DEVICES="3" python src/promptcblue_llama_peft/vllm_serving/launch_vllm.py \ 248 | --port 8090 \ 249 | --model /public/home/xlwang2/codes/ChatMed/experiments/output/promptcblue-llama-7b-pt-v0/checkpoint-800-merge \ 250 | --use-np-weights \ 251 | --max-num-batched-tokens 4096 \ 252 | --dtype half \ 253 | --tensor-parallel-size 1 254 | ''' -------------------------------------------------------------------------------- /src/ft_llama_lora/vllm_serving/merge_llama_with_lora.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | CUDA_VISIBLE_DEVICES="3" python src/promptcblue_llama_peft/vllm_serving/merge_llama_with_lora.py \ 4 | --base_model /public/home/xlwang2/codes/Med_Prompts/resources/chinese-llama-plus-lora-7b \ 5 | --lora_model ./experiments/output/promptcblue-llama-7b-pt-v0/checkpoint-800 \ 6 | --output_type huggingface \ 7 | --output_dir ./experiments/output/promptcblue-llama-7b-pt-v0/checkpoint-800-merge 8 | """ 9 | import argparse 10 | import json 11 | import os 12 | import gc 13 | import torch 14 | 15 | import sys 16 | sys.path.append("./") 17 | 18 | import peft 19 | from peft import PeftModel 20 | from transformers import LlamaForCausalLM, LlamaTokenizer 21 | from huggingface_hub import hf_hub_download 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--base_model', default=None, required=True, 25 | type=str, help="Please specify a base_model") 26 | parser.add_argument('--lora_model', default=None, required=True, 27 | type=str, help="Please specify LoRA models to be merged (ordered); use commas to separate multiple LoRA models.") 28 | parser.add_argument('--offload_dir', default=None, type=str, 29 | help="(Optional) Please specify a temp folder for offloading (useful for low-RAM machines). 
Default None (disable offload).") 30 | parser.add_argument('--output_type', default='pth',choices=['pth','huggingface'], type=str, 31 | help="save the merged model in pth or huggingface format.") 32 | parser.add_argument('--output_dir', default='./', type=str) 33 | 34 | 35 | emb_to_model_size = { 36 | 4096 : '7B', 37 | 5120 : '13B', 38 | 6656 : '30B', 39 | 8192 : '65B', 40 | } 41 | num_shards_of_models = {'7B': 1, '13B': 2} 42 | params_of_models = { 43 | '7B': 44 | { 45 | "dim": 4096, 46 | "multiple_of": 256, 47 | "n_heads": 32, 48 | "n_layers": 32, 49 | "norm_eps": 1e-06, 50 | "vocab_size": -1, 51 | }, 52 | '13B': 53 | { 54 | "dim": 5120, 55 | "multiple_of": 256, 56 | "n_heads": 40, 57 | "n_layers": 40, 58 | "norm_eps": 1e-06, 59 | "vocab_size": -1, 60 | }, 61 | } 62 | 63 | def transpose(weight, fan_in_fan_out): 64 | return weight.T if fan_in_fan_out else weight 65 | 66 | # Borrowed and modified from https://github.com/tloen/alpaca-lora 67 | def translate_state_dict_key(k): 68 | k = k.replace("base_model.model.", "") 69 | if k == "model.embed_tokens.weight": 70 | return "tok_embeddings.weight" 71 | elif k == "model.norm.weight": 72 | return "norm.weight" 73 | elif k == "lm_head.weight": 74 | return "output.weight" 75 | elif k.startswith("model.layers."): 76 | layer = k.split(".")[2] 77 | if k.endswith(".self_attn.q_proj.weight"): 78 | return f"layers.{layer}.attention.wq.weight" 79 | elif k.endswith(".self_attn.k_proj.weight"): 80 | return f"layers.{layer}.attention.wk.weight" 81 | elif k.endswith(".self_attn.v_proj.weight"): 82 | return f"layers.{layer}.attention.wv.weight" 83 | elif k.endswith(".self_attn.o_proj.weight"): 84 | return f"layers.{layer}.attention.wo.weight" 85 | elif k.endswith(".mlp.gate_proj.weight"): 86 | return f"layers.{layer}.feed_forward.w1.weight" 87 | elif k.endswith(".mlp.down_proj.weight"): 88 | return f"layers.{layer}.feed_forward.w2.weight" 89 | elif k.endswith(".mlp.up_proj.weight"): 90 | return f"layers.{layer}.feed_forward.w3.weight" 91 | elif k.endswith(".input_layernorm.weight"): 92 | return f"layers.{layer}.attention_norm.weight" 93 | elif k.endswith(".post_attention_layernorm.weight"): 94 | return f"layers.{layer}.ffn_norm.weight" 95 | elif k.endswith("rotary_emb.inv_freq") or "lora" in k: 96 | return None 97 | else: 98 | print(layer, k) 99 | raise NotImplementedError 100 | else: 101 | print(k) 102 | raise NotImplementedError 103 | 104 | 105 | def unpermute(w): 106 | return ( 107 | w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim) 108 | ) 109 | 110 | 111 | def save_shards(model_sd, num_shards: int): 112 | # Add the no_grad context manager 113 | with torch.no_grad(): 114 | if num_shards == 1: 115 | new_state_dict = {} 116 | for k, v in model_sd.items(): 117 | new_k = translate_state_dict_key(k) 118 | if new_k is not None: 119 | if "wq" in new_k or "wk" in new_k: 120 | new_state_dict[new_k] = unpermute(v) 121 | else: 122 | new_state_dict[new_k] = v 123 | 124 | os.makedirs(output_dir, exist_ok=True) 125 | print(f"Saving shard 1 of {num_shards} into {output_dir}/consolidated.00.pth") 126 | torch.save(new_state_dict, output_dir + "/consolidated.00.pth") 127 | with open(output_dir + "/params.json", "w") as f: 128 | json.dump(params, f) 129 | else: 130 | new_state_dicts = [dict() for _ in range(num_shards)] 131 | for k in list(model_sd.keys()): 132 | v = model_sd[k] 133 | new_k = translate_state_dict_key(k) 134 | if new_k is not None: 135 | if new_k=='tok_embeddings.weight': 136 | print(f"Processing {new_k}") 137 | assert 
v.size(1)%num_shards==0 138 | splits = v.split(v.size(1)//num_shards,dim=1) 139 | elif new_k=='output.weight': 140 | print(f"Processing {new_k}") 141 | splits = v.split(v.size(0)//num_shards,dim=0) 142 | 143 | elif new_k=='norm.weight': 144 | print(f"Processing {new_k}") 145 | splits = [v] * num_shards 146 | elif 'ffn_norm.weight' in new_k: 147 | print(f"Processing {new_k}") 148 | splits = [v] * num_shards 149 | elif 'attention_norm.weight' in new_k: 150 | print(f"Processing {new_k}") 151 | splits = [v] * num_shards 152 | 153 | 154 | elif 'w1.weight' in new_k: 155 | print(f"Processing {new_k}") 156 | splits = v.split(v.size(0)//num_shards,dim=0) 157 | elif 'w2.weight' in new_k: 158 | print(f"Processing {new_k}") 159 | splits = v.split(v.size(1)//num_shards,dim=1) 160 | elif 'w3.weight' in new_k: 161 | print(f"Processing {new_k}") 162 | splits = v.split(v.size(0)//num_shards,dim=0) 163 | 164 | 165 | elif 'wo.weight' in new_k: 166 | print(f"Processing {new_k}") 167 | splits = v.split(v.size(1)//num_shards,dim=1) 168 | 169 | elif 'wv.weight' in new_k: 170 | print(f"Processing {new_k}") 171 | splits = v.split(v.size(0)//num_shards,dim=0) 172 | 173 | elif "wq.weight" in new_k or "wk.weight" in new_k: 174 | print(f"Processing {new_k}") 175 | v = unpermute(v) 176 | splits = v.split(v.size(0)//num_shards,dim=0) 177 | else: 178 | print(f"Unexpected key {new_k}") 179 | raise ValueError 180 | for sd,split in zip(new_state_dicts,splits): 181 | sd[new_k] = split.clone() 182 | del split 183 | del splits 184 | del model_sd[k],v 185 | gc.collect() # Effectively enforce garbage collection 186 | 187 | os.makedirs(output_dir, exist_ok=True) 188 | for i,new_state_dict in enumerate(new_state_dicts): 189 | print(f"Saving shard {i+1} of {num_shards} into {output_dir}/consolidated.0{i}.pth") 190 | torch.save(new_state_dict, output_dir + f"/consolidated.0{i}.pth") 191 | with open(output_dir + "/params.json", "w") as f: 192 | print(f"Saving params.json into {output_dir}/params.json") 193 | json.dump(params, f) 194 | 195 | 196 | if __name__=='__main__': 197 | 198 | args = parser.parse_args() 199 | base_model_path = args.base_model 200 | lora_model_paths = [s.strip() for s in args.lora_model.split(',') if len(s.strip())!=0] 201 | output_dir = args.output_dir 202 | output_type = args.output_type 203 | offload_dir = args.offload_dir 204 | 205 | print(f"Base model: {base_model_path}") 206 | print(f"LoRA model(s) {lora_model_paths}:") 207 | 208 | if offload_dir is not None: 209 | # Load with offloading, which is useful for low-RAM machines. 210 | # Note that if you have enough RAM, please use original method instead, as it is faster. 
211 | base_model = LlamaForCausalLM.from_pretrained( 212 | base_model_path, 213 | load_in_8bit=False, 214 | torch_dtype=torch.float16, 215 | offload_folder=offload_dir, 216 | offload_state_dict=True, 217 | low_cpu_mem_usage=True, 218 | device_map={"": "cpu"}, 219 | ) 220 | else: 221 | # Original method without offloading 222 | base_model = LlamaForCausalLM.from_pretrained( 223 | base_model_path, 224 | load_in_8bit=False, 225 | torch_dtype=torch.float16, 226 | device_map={"": "cpu"}, 227 | ) 228 | print(base_model) 229 | 230 | ## infer the model size from the checkpoint 231 | embedding_size = base_model.get_input_embeddings().weight.size(1) 232 | model_size = emb_to_model_size[embedding_size] 233 | print(f"Peft version: {peft.__version__}") 234 | print(f"Loading LoRA for {model_size} model") 235 | 236 | lora_model = None 237 | lora_model_sd = None 238 | for lora_index, lora_model_path in enumerate(lora_model_paths): 239 | print(f"Loading LoRA {lora_model_path}") 240 | tokenizer = LlamaTokenizer.from_pretrained(lora_model_path) 241 | assert base_model.get_input_embeddings().weight.size(0) == len(tokenizer) 242 | 243 | # if base_model.get_input_embeddings().weight.size(0) != len(tokenizer): 244 | # base_model.resize_token_embeddings(len(tokenizer)) 245 | # print(f"Extended vocabulary size to {len(tokenizer)}") 246 | 247 | first_weight = base_model.model.layers[0].self_attn.q_proj.weight 248 | first_weight_old = first_weight.clone() 249 | 250 | if hasattr(peft.LoraModel, 'merge_and_unload'): 251 | lora_model = PeftModel.from_pretrained( 252 | base_model, 253 | lora_model_path, 254 | device_map={"": "cpu"}, 255 | torch_dtype=torch.float16, 256 | ) 257 | assert torch.allclose(first_weight_old, first_weight) 258 | print(f"Merging with merge_and_unload...") 259 | base_model = lora_model.merge_and_unload() 260 | else: 261 | base_model_sd = base_model.state_dict() 262 | try: 263 | lora_model_sd = torch.load(os.path.join(lora_model_path,'adapter_model.bin'),map_location='cpu') 264 | except FileNotFoundError: 265 | print("Cannot find lora model on the disk. Downloading lora model from hub...") 266 | filename = hf_hub_download(repo_id=lora_model_path,filename='adapter_model.bin') 267 | lora_model_sd = torch.load(filename,map_location='cpu') 268 | 269 | lora_config = peft.LoraConfig.from_pretrained(lora_model_path) 270 | lora_scaling = lora_config.lora_alpha / lora_config.r 271 | fan_in_fan_out = lora_config.fan_in_fan_out 272 | lora_keys = [k for k in lora_model_sd if 'lora_A' in k] 273 | non_lora_keys = [k for k in lora_model_sd if not 'lora_' in k] 274 | 275 | for k in non_lora_keys: 276 | print(f"merging {k}") 277 | original_k = k.replace('base_model.model.','') 278 | base_model_sd[original_k].copy_(lora_model_sd[k]) 279 | 280 | for k in lora_keys: 281 | print(f"merging {k}") 282 | original_key = k.replace('.lora_A','').replace('base_model.model.','') 283 | assert original_key in base_model_sd 284 | lora_a_key = k 285 | lora_b_key = k.replace('lora_A','lora_B') 286 | base_model_sd[original_key] += ( 287 | transpose(lora_model_sd[lora_b_key].float() @ lora_model_sd[lora_a_key].float(),fan_in_fan_out) * lora_scaling 288 | ) 289 | assert base_model_sd[original_key].dtype == torch.float16 290 | 291 | # did we do anything? 
292 | assert not torch.allclose(first_weight_old, first_weight) 293 | 294 | tokenizer.save_pretrained(output_dir) 295 | 296 | if output_type=='huggingface': 297 | print("Saving to Hugging Face format...") 298 | LlamaForCausalLM.save_pretrained( 299 | base_model, output_dir, 300 | max_shard_size="2GB" 301 | ) #, state_dict=deloreanized_sd) 302 | else: # output_type=='pth 303 | print("Saving to pth format...") 304 | 305 | base_model_sd = base_model.state_dict() 306 | del lora_model, base_model, lora_model_sd 307 | 308 | params = params_of_models[model_size] 309 | num_shards = num_shards_of_models[model_size] 310 | n_layers = params["n_layers"] 311 | n_heads = params["n_heads"] 312 | dim = params["dim"] 313 | dims_per_head = dim // n_heads 314 | base = 10000.0 315 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 316 | 317 | save_shards(model_sd=base_model_sd, num_shards=num_shards) 318 | -------------------------------------------------------------------------------- /src/ft_llama_lora/vllm_serving/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as osp 3 | from typing import Union 4 | 5 | 6 | class Prompter(object): 7 | __slots__ = ("template", "_verbose") 8 | 9 | def __init__(self, template_path: str = "", verbose: bool = False): 10 | self._verbose = verbose 11 | # if not osp.exists(template_path): 12 | # raise ValueError(f"Can't read {template_path}") 13 | # 14 | # with open(template_path) as fp: 15 | # self.template = json.load(fp) 16 | 17 | self.template = { 18 | "prompt_input": "问:\n\n答:\n", 19 | "response_split": "\n答:\n", 20 | } 21 | 22 | def generate_prompt( 23 | self, 24 | instruction: str, 25 | input: Union[None, str] = None, 26 | label: Union[None, str] = None, 27 | ) -> str: 28 | # returns the full prompt from instruction and optional input 29 | # if a label (=response, =output) is provided, it's also appended. 
30 | res = self.template["prompt_input"].replace("\n\n", "\n" + (input or "") + "\n")  # fill the query between "问:" and "答:" 31 | 32 | if label: 33 | res = f"{res}{label}" 34 | if self._verbose: 35 | print(res) 36 | return res 37 | 38 | def get_response(self, output: str) -> str: 39 | return output.split(self.template["response_split"])[1].strip() -------------------------------------------------------------------------------- /src/ft_llama_lora/vllm_serving/web_service_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Created by Michael Zhu 3 | # DataSelect AI, 2023 4 | 5 | import json 6 | import time 7 | 8 | import urllib.request 9 | 10 | import sys 11 | sys.path.append("./") 12 | 13 | 14 | def test_service(input_text): 15 | header = {'Content-Type': 'application/json'} 16 | 17 | prompt = "问:\n{}\n答:\n".format(input_text.strip()) 18 | 19 | data = { 20 | "input": input_text.strip().replace("答:", ""), 21 | "parameters": {}, 22 | } 23 | request = urllib.request.Request( 24 | url='http://127.0.0.1:8090/generate', 25 | headers=header, 26 | data=json.dumps(data).encode('utf-8') 27 | ) 28 | 29 | result = None 30 | try: 31 | response = urllib.request.urlopen(request, timeout=30) 32 | res = response.read().decode('utf-8') 33 | result = json.loads(res) 34 | print(json.dumps(data, ensure_ascii=False, indent=2)) 35 | print(json.dumps(result, ensure_ascii=False, indent=2)) 36 | 37 | except Exception as e: 38 | print(e) 39 | 40 | return result 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | f_out = open("src/web_services/test_examples/test_preds_0.json", "a", encoding="utf-8", buffering=1) 46 | with open("data/promptcblue/test_a_open_0/dev.json", "r", encoding="utf-8") as f: 47 | 48 | for line in f: 49 | line = line.strip() 50 | if not line: 51 | continue 52 | 53 | line = json.loads(line) 54 | 55 | t0 = time.time() 56 | result = test_service(line["input"]) 57 | t1 = time.time() 58 | print("time cost: ", t1 - t0) 59 | 60 | f_out.write( 61 | json.dumps(result, ensure_ascii=False) + "\n" 62 | ) 63 | 64 | --------------------------------------------------------------------------------
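
A minimal sketch of how the /generate endpoint exposed by launch_vllm.py can be queried once the server is running. The port (8090), the request fields ("input", "parameters") and the response keys follow FastAPIServer.generate and web_service_test.py above; the query text and the sampling values are only illustrative placeholders, not values required by the service.

curl -X POST http://127.0.0.1:8090/generate \
  -H "Content-Type: application/json" \
  -d '{"input": "患者反复头晕三天,伴有恶心,需要做哪些检查?", "parameters": {"max_tokens": 256, "temperature": 0.7}}'
# expected response shape: {"generated_text": "...", "num_output_tokens_cf": <number of generated tokens>, "error": null}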