├── CITATION.cff
├── MedQA-ChatGLM
│   ├── __init__.py
│   ├── export_weights.py
│   ├── finetune.py
│   ├── infer.py
│   ├── load_export_weights.py
│   ├── train_ppo.py
│   ├── train_rm.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── config.py
│   │   ├── other.py
│   │   ├── pairwise.py
│   │   ├── ppo.py
│   │   └── seq2seq.py
│   └── web_demo.py
├── README.md
├── data
│   ├── bulid_CMD.py
│   ├── bulid_cMedQA.py
│   ├── bulid_cMedQA2.py
│   ├── comparison_gpt4_data_en.json
│   ├── comparison_gpt4_data_zh.json
│   ├── dataset_info-plus.json
│   ├── dataset_info.json
│   ├── get_decoder_type.py
│   ├── merge-CMD.py
│   ├── merge-MedDialog.py
│   ├── merge-cMedQA.py
│   ├── pre-MedDialog.py
│   └── self_cognition.json
├── docs
│   ├── Understanding_ChatGPT.pdf
│   └── 参数详解.md
├── images
│   ├── data-plus.png
│   ├── data-plus2.png
│   └── model.png
└── requirements.txt
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: MedQA-ChatGLM
3 | message: >-
4 |   If you use this software, please cite it using these
5 |   metadata.
6 | type: software
7 | authors:
8 |   - given-names: Rongsheng Wang
9 |     orcid: https://orcid.org/my-orcid?orcid=0000-0003-2390-5999
10 | repository-code: 'https://github.com/WangRongsheng/MedQA-ChatGLM'
11 | url: 'https://github.com/WangRongsheng/MedQA-ChatGLM'
12 | abstract: A Medical QA Model Fine-tuned on ChatGLM Using Multiple fine-tuning Method and Real Medical QA Data
13 | license: CC BY-NC-SA 4.0
14 | 
--------------------------------------------------------------------------------
/MedQA-ChatGLM/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import (
2 |     load_pretrained,
3 |     ModelArguments
4 | )
5 | 
--------------------------------------------------------------------------------
/MedQA-ChatGLM/export_weights.py:
--------------------------------------------------------------------------------
1 | from utils import load_pretrained, ModelArguments
2 | import argparse
3 | 
4 | if __name__ == "__main__":
5 |     parser = argparse.ArgumentParser()
6 |     # Add arguments
7 |     parser.add_argument('-finetuning_weights_path', '--finetuning_weights_path', dest='fwp', type=str, required=True)
8 |     parser.add_argument('-save_weights_path', '--save_weights_path', dest='swp', type=str, required=True)
9 |     # Parse arguments
10 |     args = parser.parse_args()
11 | 
12 |     model_args = ModelArguments(checkpoint_dir=args.fwp)
13 |     model, tokenizer = load_pretrained(model_args)
14 |     # Save the merged weights
15 |     model.base_model.model.save_pretrained(args.swp)
16 |     # Save the tokenizer
17 |     tokenizer.save_pretrained(args.swp)
18 | 
19 |     print('合并模型完成,保存在:', args.swp)
--------------------------------------------------------------------------------
/MedQA-ChatGLM/finetune.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Implements several parameter-efficient supervised fine-tuning methods for ChatGLM.
3 | # This code is inspired by https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
4 | 
5 | 
6 | from utils import (
7 |     load_pretrained,
8 |     prepare_args,
9 |     prepare_data,
10 |     preprocess_data,
11 |     plot_loss,
12 |     Seq2SeqDataCollatorForChatGLM,
13 |     ComputeMetrics,
14 |     Seq2SeqTrainerForChatGLM
15 | )
16 | 
17 | 
18 | def main():
19 | 
20 |     # Prepare pretrained model and dataset
21 |     model_args, data_args, training_args, finetuning_args = prepare_args()
22 |     dataset = prepare_data(model_args, data_args)
23 |     model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="sft")
24 |     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
25 |     data_collator = Seq2SeqDataCollatorForChatGLM(
26 |         tokenizer=tokenizer,
27 |         model=model,
28 |         ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss,
29 |         inference_mode=(not training_args.do_train)
30 |     )
31 | 
32 |     # Override the decoding parameters of Trainer
33 |     training_args.generation_max_length = training_args.generation_max_length if \
34 |         training_args.generation_max_length is not None else data_args.max_target_length
35 |     training_args.generation_num_beams = data_args.num_beams if \
36 |         data_args.num_beams is not None else training_args.generation_num_beams
37 | 
38 |     # Initialize our Trainer
39 |     trainer = Seq2SeqTrainerForChatGLM(
40 |         finetuning_args=finetuning_args,
41 |         model=model,
42 |         args=training_args,
43 |         train_dataset=dataset if training_args.do_train else None,
44 |         eval_dataset=dataset if training_args.do_eval else None,
45 |         tokenizer=tokenizer,
46 |         data_collator=data_collator,
47 |         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None
48 |     )
49 | 
50 |     # Keyword arguments for `model.generate`
51 |     gen_kwargs = {
52 |         "do_sample": True,
53 |         "top_p": 0.7,
54 |         "max_length": 768,
55 |         "temperature": 0.95
56 |     }
57 | 
58 |     # Training
59 |     if training_args.do_train:
60 |         train_result = trainer.train()
61 |         trainer.log_metrics("train", train_result.metrics)
62 |         trainer.save_metrics("train", train_result.metrics)
63 |         trainer.save_state() # along with the loss values
64 |         trainer.save_model()
65 |         if finetuning_args.plot_loss:
66 |             plot_loss(training_args)
67 | 
68 |     # Evaluation
69 |     if training_args.do_eval:
70 |         metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
71 |         trainer.log_metrics("eval", metrics)
72 |         trainer.save_metrics("eval", metrics)
73 | 
74 |     # Predict
75 |     if training_args.do_predict:
76 |         predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
77 |         trainer.log_metrics("predict", predict_results.metrics)
78 |         trainer.save_metrics("predict", predict_results.metrics)
79 |         trainer.save_predictions(predict_results, tokenizer)
80 | 
81 | 
82 | def _mp_fn(index):
83 |     # For xla_spawn (TPUs)
84 |     main()
85 | 
86 | 
87 | if __name__ == "__main__":
88 |     main()
89 | 
--------------------------------------------------------------------------------
/MedQA-ChatGLM/infer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Implements stream chat in the command line for ChatGLM fine-tuned with PEFT.
3 | # This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
4 | 
5 | 
6 | import os
7 | import signal
8 | import platform
9 | 
10 | from utils import ModelArguments, load_pretrained
11 | from transformers import HfArgumentParser
12 | 
13 | 
14 | os_name = platform.system()
15 | clear_command = "cls" if os_name == "Windows" else "clear"
16 | stop_stream = False
17 | welcome = "欢迎使用 MedQA-ChatGLM 模型,输入内容即可对话,clear清空对话历史,stop终止程序"
18 | 
19 | 
20 | def build_prompt(history):
21 |     prompt = welcome
22 |     for query, response in history:
23 |         prompt += f"\n\nUser: {query}"
24 |         prompt += f"\n\nChatGLM-6B: {response}"
25 |     return prompt
26 | 
27 | 
28 | def signal_handler(signal, frame):
29 |     global stop_stream
30 |     stop_stream = True
31 | 
32 | 
33 | def main():
34 | 
35 |     global stop_stream
36 |     parser = HfArgumentParser(ModelArguments)
37 |     model_args, = parser.parse_args_into_dataclasses()
38 |     model, tokenizer = load_pretrained(model_args)
39 |     model = model.cuda()
40 |     model.eval()
41 | 
42 |     history = []
43 |     print(welcome)
44 |     while True:
45 |         try:
46 |             query = input("\nInput: ")
47 |         except UnicodeDecodeError:
48 |             print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
49 |             continue
50 |         except Exception:
51 |             raise
52 | 
53 |         if query.strip() == "stop":
54 |             break
55 |         if query.strip() == "clear":
56 |             history = []
57 |             os.system(clear_command)
58 |             print(welcome)
59 |             continue
60 | 
61 |         count = 0
62 |         for _, history in model.stream_chat(tokenizer, query, history=history):
63 |             if stop_stream:
64 |                 stop_stream = False
65 |                 break
66 |             else:
67 |                 count += 1
68 |                 if count % 8 == 0:
69 |                     os.system(clear_command)
70 |                     print(build_prompt(history), flush=True)
71 |                     signal.signal(signal.SIGINT, signal_handler)
72 |         os.system(clear_command)
73 |         print(build_prompt(history), flush=True)
74 | 
75 | 
76 | if __name__ == "__main__":
77 |     main()
78 | 
--------------------------------------------------------------------------------
/MedQA-ChatGLM/load_export_weights.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoTokenizer
2 | import argparse
3 | 
4 | def get_model(load_path):
5 |     tokenizer = AutoTokenizer.from_pretrained(load_path, trust_remote_code=True)
6 |     config = AutoConfig.from_pretrained(load_path, trust_remote_code=True, pre_seq_len=128)
7 |     model = AutoModel.from_pretrained(load_path, config=config, trust_remote_code=True).half().cuda()
8 |     model = model.eval()
9 | 
10 |     return tokenizer, model
11 | 
12 | if __name__ == "__main__":
13 |     parser = argparse.ArgumentParser()
14 |     # Add arguments
15 |     parser.add_argument('-save_weights_path', '--save_weights_path', dest='swp', type=str, required=True)
16 |     # Parse arguments
17 |     args = parser.parse_args()
18 | 
19 |     tokenizer, model = get_model(args.swp)
20 |     print(model)
21 |     print('加载完成')
22 | 
--------------------------------------------------------------------------------
/MedQA-ChatGLM/train_ppo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Implements parameter-efficient PPO training of fine-tuned ChatGLM.
3 | # This code is inspired by: 4 | # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py 5 | 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch.optim import AdamW 10 | 11 | from trl import PPOConfig 12 | from trl.core import LengthSampler 13 | 14 | from utils import ( 15 | prepare_args, 16 | prepare_data, 17 | load_pretrained, 18 | preprocess_data, 19 | PPODataCollatorForChatGLM, 20 | PPOTrainerForChatGLM, 21 | compute_rewards, 22 | plot_loss 23 | ) 24 | 25 | 26 | def main(): 27 | 28 | # prepare pretrained model and dataset 29 | model_args, data_args, training_args, finetuning_args = prepare_args() 30 | dataset = prepare_data(model_args, data_args) 31 | model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="ppo") 32 | dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo") 33 | data_collator = PPODataCollatorForChatGLM( 34 | tokenizer=tokenizer, 35 | min_input_length=data_args.max_source_length, # avoid truncating input sequences 36 | max_input_length=data_args.max_source_length, 37 | inference_mode=(not training_args.do_train) 38 | ) 39 | 40 | ppo_config = PPOConfig( 41 | model_name=model_args.model_name_or_path, 42 | learning_rate=training_args.learning_rate, 43 | mini_batch_size=max(training_args.per_device_train_batch_size // 4, 1), 44 | batch_size=training_args.per_device_train_batch_size, 45 | gradient_accumulation_steps=training_args.gradient_accumulation_steps, 46 | ppo_epochs=int(training_args.num_train_epochs), 47 | max_grad_norm=training_args.max_grad_norm 48 | ) 49 | 50 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate) 51 | 52 | # Initialize our Trainer 53 | ppo_trainer = PPOTrainerForChatGLM( 54 | training_args=training_args, 55 | finetuning_args=finetuning_args, 56 | config=ppo_config, 57 | model=model, 58 | ref_model=None, 59 | tokenizer=tokenizer, 60 | dataset=dataset, 61 | data_collator=data_collator, 62 | optimizer=optimizer 63 | ) 64 | 65 | # Keyword arguments for `model.generate` 66 | gen_kwargs = { 67 | "top_k": 0.0, 68 | "top_p": 1.0, 69 | "do_sample": True, 70 | "pad_token_id": tokenizer.pad_token_id, 71 | "eos_token_id": tokenizer.eos_token_id 72 | } 73 | output_length_sampler = LengthSampler(data_args.max_target_length // 2, data_args.max_target_length) 74 | 75 | for batch in tqdm(ppo_trainer.dataloader): 76 | queries = batch["input_ids"] # left-padded sequences 77 | 78 | model.gradient_checkpointing_disable() 79 | model.config.use_cache = True 80 | 81 | # Get response from ChatGLM 82 | responses_with_queries = ppo_trainer.generate(queries, length_sampler=output_length_sampler, **gen_kwargs) 83 | responses = responses_with_queries[:, queries.size(1):] # right-padded sequences 84 | # batch["response"] = tokenizer.batch_decode(responses, skip_special_tokens=True) # avoid error 85 | 86 | for i in range(responses_with_queries.size(0)): # change to right-padding 87 | start = (responses_with_queries[i] != tokenizer.pad_token_id).nonzero()[0].item() 88 | responses_with_queries[i] = torch.cat((responses_with_queries[i][start:], responses_with_queries[i][:start])) 89 | 90 | # Compute rewards 91 | rewards = compute_rewards(responses_with_queries, model, tokenizer) 92 | 93 | # Run PPO step 94 | model.gradient_checkpointing_enable() 95 | model.config.use_cache = False 96 | split_into_list = lambda x: [x[i] for i in range(x.size(0))] 97 | stats = 
ppo_trainer.step(*map(split_into_list, [queries, responses, rewards])) 98 | 99 | ppo_trainer.log_stats(stats, batch, rewards) 100 | ppo_trainer.update_stats(stats, batch, rewards) 101 | 102 | ppo_trainer.save_state() # along with the loss values 103 | ppo_trainer.save_model() 104 | if finetuning_args.plot_loss: 105 | plot_loss(training_args) 106 | 107 | 108 | def _mp_fn(index): 109 | # For xla_spawn (TPUs) 110 | main() 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/train_rm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implements parameter-efficient training of a reward model based on ChatGLM. 3 | # This code is inspired by: 4 | # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py 5 | # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py 6 | 7 | 8 | from utils import ( 9 | prepare_args, 10 | prepare_data, 11 | load_pretrained, 12 | preprocess_data, 13 | PairwiseDataCollatorForChatGLM, 14 | PairwiseTrainerForChatGLM, 15 | plot_loss 16 | ) 17 | 18 | def main(): 19 | 20 | # prepare pretrained model and dataset 21 | model_args, data_args, training_args, finetuning_args = prepare_args() 22 | dataset = prepare_data(model_args, data_args) 23 | model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="rwd") 24 | dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rwd") 25 | data_collator = PairwiseDataCollatorForChatGLM( 26 | tokenizer=tokenizer, 27 | inference_mode=(not training_args.do_train) 28 | ) 29 | 30 | training_args.remove_unused_columns = False # Important for pairwise dataset 31 | 32 | # Initialize our Trainer 33 | trainer = PairwiseTrainerForChatGLM( 34 | finetuning_args=finetuning_args, 35 | model=model, 36 | args=training_args, 37 | train_dataset=dataset if training_args.do_train else None, 38 | eval_dataset=dataset if training_args.do_eval else None, 39 | tokenizer=tokenizer, 40 | data_collator=data_collator 41 | ) 42 | 43 | # Training 44 | if training_args.do_train: 45 | train_result = trainer.train() 46 | trainer.log_metrics("train", train_result.metrics) 47 | trainer.save_metrics("train", train_result.metrics) 48 | trainer.save_state() # along with the loss values 49 | trainer.save_model() 50 | if finetuning_args.plot_loss: 51 | plot_loss(training_args) 52 | 53 | 54 | def _mp_fn(index): 55 | # For xla_spawn (TPUs) 56 | main() 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import ( 2 | load_pretrained, 3 | prepare_args, 4 | prepare_data, 5 | preprocess_data 6 | ) 7 | 8 | from .seq2seq import ( 9 | Seq2SeqDataCollatorForChatGLM, 10 | ComputeMetrics, 11 | Seq2SeqTrainerForChatGLM 12 | ) 13 | 14 | from .pairwise import ( 15 | PairwiseDataCollatorForChatGLM, 16 | PairwiseTrainerForChatGLM 17 | ) 18 | 19 | from .ppo import ( 20 | PPODataCollatorForChatGLM, 21 | PPOTrainerForChatGLM, 22 | compute_rewards 23 | ) 24 | 25 | from .config import ModelArguments 26 | 27 | from .other import plot_loss 28 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/common.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import hashlib 5 | import logging 6 | from typing import Literal, Optional, Tuple 7 | 8 | import transformers 9 | from transformers import ( 10 | AutoConfig, 11 | AutoModel, 12 | AutoTokenizer, 13 | HfArgumentParser, 14 | Seq2SeqTrainingArguments 15 | ) 16 | from transformers.utils import check_min_version 17 | from transformers.utils.versions import require_version 18 | from transformers.modeling_utils import PreTrainedModel 19 | from transformers.tokenization_utils import PreTrainedTokenizer 20 | 21 | import datasets 22 | from datasets import Dataset, concatenate_datasets, load_dataset 23 | 24 | from peft import ( 25 | PeftModel, 26 | TaskType, 27 | LoraConfig, 28 | get_peft_model 29 | ) 30 | 31 | from trl import AutoModelForCausalLMWithValueHead 32 | 33 | from .config import ( 34 | ModelArguments, 35 | DataTrainingArguments, 36 | FinetuningArguments 37 | ) 38 | 39 | from .other import ( 40 | load_trainable_params, 41 | load_valuehead_params, 42 | print_trainable_params, 43 | prepare_model_for_training, 44 | IGNORE_INDEX, 45 | FINETUNING_ARGS_NAME 46 | ) 47 | 48 | 49 | logger = logging.getLogger(__name__) # setup logging 50 | logger.setLevel(logging.INFO) 51 | logging.basicConfig( 52 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 53 | datefmt="%m/%d/%Y %H:%M:%S", 54 | handlers=[logging.StreamHandler(sys.stdout)], 55 | ) 56 | 57 | 58 | check_min_version("4.27.4") 59 | require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0") 60 | require_version("peft>=0.3.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") 61 | require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") 62 | 63 | 64 | def init_adapter( 65 | model: PreTrainedModel, 66 | model_args: ModelArguments, 67 | finetuning_args: FinetuningArguments, 68 | is_trainable: bool 69 | ) -> None: 70 | r""" 71 | Initializes the adapters. 72 | 73 | Note that the trainable parameters must be cast to float32. 
74 | """ 75 | 76 | if finetuning_args.finetuning_type == "none" and is_trainable: 77 | raise ValueError("You cannot use finetuning_type=none while training.") 78 | 79 | if finetuning_args.finetuning_type == "freeze": 80 | logger.info("Fine-tuning method: Freeze") 81 | for name, param in model.named_parameters(): 82 | if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers): 83 | param.requires_grad_(False) 84 | else: 85 | param.data = param.data.to(torch.float32) 86 | 87 | if model_args.checkpoint_dir is not None: # freeze only accepts a single checkpoint 88 | load_trainable_params(model, model_args.checkpoint_dir[0]) 89 | 90 | if finetuning_args.finetuning_type == "p_tuning": 91 | logger.info("Fine-tuning method: P-Tuning v2") 92 | model.transformer.prefix_encoder.float() # other parameters are already fixed 93 | 94 | if model_args.checkpoint_dir is not None: # p-tuning v2 only accepts a single checkpoint 95 | load_trainable_params(model, model_args.checkpoint_dir[0]) 96 | 97 | if finetuning_args.finetuning_type == "lora": 98 | logger.info("Fine-tuning method: LoRA") 99 | lastest_checkpoint = None 100 | 101 | if model_args.checkpoint_dir is not None: 102 | if is_trainable and finetuning_args.resume_lora_training: # continually training on the lora weights 103 | checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] 104 | else: 105 | checkpoints_to_merge = model_args.checkpoint_dir 106 | 107 | for checkpoint in checkpoints_to_merge: # https://github.com/huggingface/peft/issues/280#issuecomment-1500805831 108 | model = PeftModel.from_pretrained(model, checkpoint) 109 | model = model.merge_and_unload() 110 | 111 | logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) 112 | 113 | if lastest_checkpoint is not None: # resume lora training 114 | model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=True) 115 | 116 | if lastest_checkpoint is None: # create new lora weights 117 | lora_config = LoraConfig( 118 | task_type=TaskType.CAUSAL_LM, 119 | inference_mode=False, 120 | r=finetuning_args.lora_rank, 121 | lora_alpha=finetuning_args.lora_alpha, 122 | lora_dropout=finetuning_args.lora_dropout, 123 | target_modules=finetuning_args.lora_target 124 | ) 125 | model = get_peft_model(model, lora_config) 126 | 127 | if not is_trainable: 128 | for param in model.parameters(): 129 | param.requires_grad_(False) # fix all params 130 | param.data = param.data.to(torch.float16) # cast all params to float16 131 | 132 | return model 133 | 134 | 135 | def load_pretrained( 136 | model_args: ModelArguments, 137 | training_args: Optional[Seq2SeqTrainingArguments] = None, 138 | finetuning_args: Optional[FinetuningArguments] = None, 139 | is_trainable: Optional[bool] = False, 140 | stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft" 141 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 142 | r""" 143 | Load pretrained model and tokenizer. 
144 | """ 145 | 146 | if (not is_trainable) and (model_args.checkpoint_dir is None): 147 | logger.warning("Checkpoint is not found at evaluation, load the original model.") 148 | finetuning_args = FinetuningArguments(finetuning_type="none") 149 | 150 | if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint 151 | for checkpoint_dir in model_args.checkpoint_dir: 152 | if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)): 153 | raise ValueError("The fine-tuning arguments are not found in the provided dictionary.") 154 | logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) 155 | finetuning_args = torch.load(os.path.join(model_args.checkpoint_dir[0], FINETUNING_ARGS_NAME)) 156 | 157 | quantization = None 158 | if model_args.quantization_bit is not None: 159 | if is_trainable: 160 | if finetuning_args.finetuning_type != "p_tuning": 161 | quantization = "bnb" # use bnb's quantization 162 | else: 163 | quantization = "cpm" # use cpm's quantization 164 | else: 165 | quantization = "cpm" 166 | 167 | config_kwargs = { 168 | "trust_remote_code": True, 169 | "cache_dir": model_args.cache_dir, 170 | "revision": model_args.model_revision, 171 | "use_auth_token": True if model_args.use_auth_token else None, 172 | } 173 | 174 | tokenizer = AutoTokenizer.from_pretrained( 175 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 176 | use_fast=model_args.use_fast_tokenizer, 177 | padding_side="left", 178 | **config_kwargs 179 | ) 180 | 181 | config = AutoConfig.from_pretrained( 182 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 183 | **config_kwargs 184 | ) 185 | 186 | # P-Tuning v2 configurations. 187 | # We use the built-in p-tuning method of ChatGLM, we cannot use PEFT since the attention masks of ChatGLM are unusual. >_< 188 | if finetuning_args.finetuning_type == "p_tuning": 189 | config.pre_seq_len = finetuning_args.pre_seq_len # enable this will fix other parameters automatically 190 | config.prefix_projection = finetuning_args.prefix_projection 191 | 192 | # Quantization configurations for Freeze and LoRA in training (using bitsandbytes library). 193 | if quantization == "bnb": 194 | if model_args.quantization_bit != 8: 195 | raise ValueError("Freeze and LoRA fine-tuning only accept 8-bit quantization.") 196 | require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.") 197 | from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible 198 | cuda = get_cuda_lib_handle() 199 | cc = get_compute_capability(cuda) 200 | if not is_cublasLt_compatible(cc): 201 | raise ValueError("The current GPU(s) is incompatible with quantization.") 202 | config_kwargs["load_in_8bit"] = True 203 | config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit 204 | 205 | # Load and prepare pretrained models (without valuehead). 206 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs) 207 | model = prepare_model_for_training(model) if is_trainable else model 208 | model = init_adapter(model, model_args, finetuning_args, is_trainable) 209 | 210 | # Quantization with the built-in method for P-Tuning v2 training or evaluation. 211 | # Model parameters should be cast to float16 in quantized P-Tuning setting. 
212 | if quantization == "cpm": 213 | if model_args.quantization_bit != 4 and model_args.quantization_bit != 8: 214 | raise ValueError("P-Tuning v2 and inference modes only accept 4-bit or 8-bit quantization.") 215 | 216 | if is_trainable and training_args.fp16: 217 | raise ValueError("FP16 training conflicts with cpm quantization.") 218 | 219 | model = model.quantize(model_args.quantization_bit) 220 | for name, param in model.named_parameters(): 221 | if "prefix_encoder" not in name: 222 | param.data = param.data.to(torch.float16) # convert all params in half precision except prefix_encoder 223 | 224 | if quantization is not None: 225 | logger.info("Quantized model to {} bit.".format(model_args.quantization_bit)) 226 | 227 | if stage == "rwd" or stage == "ppo": # add value head 228 | model = AutoModelForCausalLMWithValueHead.from_pretrained(model) 229 | if stage == "ppo": # load reward model 230 | model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) 231 | load_valuehead_params(model, model_args.reward_model) 232 | 233 | # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model 234 | # To meet the compliance requirements of the transformers library 235 | if quantization == "bnb" and model_args.quantization_bit == 8: 236 | model._is_int8_training_enabled = True 237 | 238 | print_trainable_params(model) 239 | 240 | return model, tokenizer 241 | 242 | 243 | def prepare_args() -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]: 244 | 245 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments)) 246 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 247 | # Provide arguments with a json file. 248 | model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 249 | else: 250 | model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() 251 | 252 | # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) 253 | if int(training_args.do_train) + int(training_args.do_eval) + int(training_args.do_predict) != 1: 254 | raise ValueError("We must perform single operation among do_train, do_eval and do_predict.") 255 | 256 | if model_args.quantization_bit is not None and training_args.do_train == False: 257 | logger.warning("We do not recommend to evaluaute model in 4/8-bit mode.") 258 | 259 | if not training_args.fp16: 260 | logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.") 261 | 262 | training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning 263 | 264 | # Set logger 265 | if training_args.should_log: 266 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
267 | transformers.utils.logging.set_verbosity_info() 268 | 269 | log_level = training_args.get_process_log_level() 270 | logger.setLevel(log_level) 271 | datasets.utils.logging.set_verbosity(log_level) 272 | transformers.utils.logging.set_verbosity(log_level) 273 | transformers.utils.logging.enable_default_handler() 274 | transformers.utils.logging.enable_explicit_format() 275 | 276 | # Log on each process the small summary: 277 | logger.warning( 278 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" 279 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 280 | ) 281 | logger.info(f"Training/evaluation parameters {training_args}") 282 | 283 | # Set seed before initializing model. 284 | transformers.set_seed(training_args.seed) 285 | 286 | return model_args, data_args, training_args, finetuning_args 287 | 288 | 289 | def prepare_data( 290 | model_args: ModelArguments, 291 | data_args: DataTrainingArguments 292 | ) -> Dataset: 293 | 294 | def checksum(file_path, hash): 295 | with open(file_path, "rb") as datafile: 296 | binary_data = datafile.read() 297 | sha1 = hashlib.sha1(binary_data).hexdigest() 298 | if sha1 != hash: 299 | logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path)) 300 | 301 | max_samples = data_args.max_samples 302 | all_datasets = [] # support multiple datasets 303 | 304 | for dataset_info in data_args.dataset_list: 305 | 306 | logger.info("Loading dataset {}...".format(dataset_info)) 307 | 308 | if dataset_info.load_from == "hf_hub": 309 | raw_datasets = load_dataset(dataset_info.dataset_name, cache_dir=model_args.cache_dir) 310 | elif dataset_info.load_from == "script": 311 | raw_datasets = load_dataset( 312 | os.path.join(data_args.dataset_dir, dataset_info.dataset_name), 313 | cache_dir=model_args.cache_dir 314 | ) 315 | elif dataset_info.load_from == "file": 316 | data_file = os.path.join(data_args.dataset_dir, dataset_info.file_name) # support json, jsonl and csv 317 | extension = dataset_info.file_name.split(".")[-1] 318 | 319 | if dataset_info.file_sha1 is not None: 320 | checksum(data_file, dataset_info.file_sha1) 321 | else: 322 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.") 323 | 324 | raw_datasets = load_dataset( 325 | extension, 326 | data_files=data_file, 327 | cache_dir=model_args.cache_dir, 328 | use_auth_token=True if model_args.use_auth_token else None 329 | ) 330 | else: 331 | raise NotImplementedError 332 | 333 | dataset = raw_datasets[data_args.split] 334 | 335 | if max_samples is not None: 336 | max_samples_temp = min(len(dataset), max_samples) 337 | dataset = dataset.select(range(max_samples_temp)) 338 | 339 | dummy_data = [None] * len(dataset) 340 | for column, column_name in [ 341 | ("prompt_column", "prompt"), 342 | ("query_column", "query"), 343 | ("response_column", "response"), 344 | ("history_column", "history") 345 | ]: # every dataset will have 4 columns same as each other 346 | if getattr(dataset_info, column) != column_name: 347 | if getattr(dataset_info, column): 348 | dataset = dataset.rename_column(getattr(dataset_info, column), column_name) 349 | else: # None or empty string 350 | dataset = dataset.add_column(column_name, dummy_data) 351 | all_datasets.append(dataset) 352 | 353 | if len(data_args.dataset_list) == 1: 354 | all_datasets = all_datasets[0] 355 | else: 356 | all_datasets = concatenate_datasets(all_datasets) 357 | 358 | return 
all_datasets 359 | 360 | 361 | def preprocess_data( 362 | dataset: Dataset, 363 | tokenizer: PreTrainedTokenizer, 364 | data_args: DataTrainingArguments, 365 | training_args: Seq2SeqTrainingArguments, 366 | stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft" 367 | ) -> Dataset: 368 | 369 | column_names = list(dataset.column_names) 370 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 371 | 372 | def format_example(examples): # support question with a single answer or multiple answers 373 | for i in range(len(examples["prompt"])): 374 | if examples["prompt"][i] and examples["response"][i]: 375 | query, answer = examples["prompt"][i], examples["response"][i] 376 | if examples["query"][i]: 377 | query += examples["query"][i] 378 | if examples["history"][i]: 379 | prompt = "" 380 | history = examples["history"][i] 381 | for i, (old_query, response) in enumerate(history): 382 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) 383 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 384 | else: 385 | prompt = query 386 | prompt = prefix + prompt 387 | yield prompt, answer 388 | 389 | def preprocess_function_train(examples): 390 | # build inputs with format `X [gMASK] [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] [BOS] Y [EOS]` 391 | model_inputs = {"input_ids": [], "labels": []} 392 | for prompt, answer in format_example(examples): 393 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 394 | target_ids = tokenizer.encode(text=answer, add_special_tokens=False) 395 | 396 | if len(source_ids) > data_args.max_source_length - 1: # gmask token 397 | source_ids = source_ids[:data_args.max_source_length - 1] 398 | if len(target_ids) > data_args.max_target_length - 2: # bos and eos tokens 399 | target_ids = target_ids[:data_args.max_target_length - 2] 400 | 401 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids) 402 | 403 | context_length = input_ids.index(tokenizer.bos_token_id) 404 | labels = [IGNORE_INDEX] * context_length + input_ids[context_length:] 405 | 406 | model_inputs["input_ids"].append(input_ids) 407 | model_inputs["labels"].append(labels) 408 | return model_inputs 409 | 410 | def preprocess_function_eval(examples): 411 | # build inputs with format `[PAD] ... 
[PAD] X [gMASK] [BOS]` and labels with format `Y [gMASK] [BOS]` 412 | # left-padding is needed for prediction, use the built-in function of the tokenizer 413 | inputs, targets = [], [] 414 | for prompt, answer in format_example(examples): 415 | inputs.append(prompt) 416 | targets.append(answer) 417 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) 418 | labels = tokenizer(text_target=targets, max_length=data_args.max_target_length, truncation=True) # no padding 419 | if data_args.ignore_pad_token_for_loss: 420 | labels["input_ids"] = [ 421 | [(l_id if l_id != tokenizer.pad_token_id else IGNORE_INDEX) for l_id in label] for label in labels["input_ids"] 422 | ] 423 | model_inputs["labels"] = labels["input_ids"] 424 | return model_inputs 425 | 426 | def preprocess_function_train_pair(examples): 427 | # build input pairs with format `X [gMASK] [BOS] Y1 [EOS]` and `X [gMASK] [BOS] Y2 [EOS]` 428 | model_inputs = {"accept_ids": [], "reject_ids": []} 429 | for prompt, answer in format_example(examples): 430 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 431 | accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) 432 | reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) 433 | 434 | if len(source_ids) > data_args.max_source_length - 1: 435 | source_ids = source_ids[:data_args.max_source_length - 1] 436 | if len(accept_ids) > data_args.max_target_length - 2: 437 | accept_ids = accept_ids[:data_args.max_target_length - 2] 438 | if len(reject_ids) > data_args.max_target_length - 2: 439 | reject_ids = reject_ids[:data_args.max_target_length - 2] 440 | 441 | accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids, accept_ids) 442 | reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids, reject_ids) 443 | 444 | model_inputs["accept_ids"].append(accept_ids) 445 | model_inputs["reject_ids"].append(reject_ids) 446 | return model_inputs 447 | 448 | def preprocess_function_train_ppo(examples): 449 | # build inputs with format `X [gMASK] [BOS]` 450 | model_inputs = {"input_ids": []} 451 | for prompt, _ in format_example(examples): 452 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 453 | 454 | if len(source_ids) > data_args.max_source_length - 1: # gmask token 455 | source_ids = source_ids[:data_args.max_source_length - 1] 456 | 457 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids) 458 | model_inputs["input_ids"].append(input_ids) 459 | return model_inputs 460 | 461 | def print_sft_dataset_example(example): 462 | print("input_ids:\n{}".format(example["input_ids"])) 463 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) 464 | print("label_ids:\n{}".format(example["labels"])) 465 | print("labels:\n{}".format(tokenizer.decode(example["labels"]))) 466 | 467 | def print_pairwise_dataset_example(example): 468 | print("accept_ids:\n{}".format(example["accept_ids"])) 469 | print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"]))) 470 | print("reject_ids:\n{}".format(example["reject_ids"])) 471 | print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"]))) 472 | 473 | def print_ppo_dataset_example(example): 474 | print("input_ids:\n{}".format(example["input_ids"])) 475 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) 476 | 477 | if stage == "sft": 478 | preprocess_function = preprocess_function_train if training_args.do_train else preprocess_function_eval 479 | elif stage == 
"rwd": 480 | preprocess_function = preprocess_function_train_pair 481 | elif stage == "ppo": 482 | preprocess_function = preprocess_function_train_ppo 483 | 484 | with training_args.main_process_first(desc="dataset map pre-processing"): 485 | dataset = dataset.map( 486 | preprocess_function, 487 | batched=True, 488 | num_proc=data_args.preprocessing_num_workers, 489 | remove_columns=column_names, 490 | load_from_cache_file=not data_args.overwrite_cache, 491 | desc="Running tokenizer on dataset" 492 | ) 493 | 494 | if stage == "sft": 495 | print_sft_dataset_example(dataset[0]) 496 | elif stage == "rwd": 497 | print_pairwise_dataset_example(dataset[0]) 498 | elif stage == "ppo": 499 | print_ppo_dataset_example(dataset[0]) 500 | 501 | return dataset 502 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Optional 4 | from dataclasses import dataclass, field 5 | 6 | 7 | CHATGLM_REPO_NAME = "THUDM/chatglm-6b" 8 | CHATGLM_LASTEST_HASH = "a8ede826cf1b62bd3c78bdfb3625c7c5d2048fbd" 9 | 10 | 11 | @dataclass 12 | class DatasetAttr: 13 | 14 | load_from: str 15 | dataset_name: Optional[str] = None 16 | file_name: Optional[str] = None 17 | file_sha1: Optional[str] = None 18 | 19 | def __post_init__(self): 20 | self.prompt_column = "instruction" 21 | self.query_column = "input" 22 | self.response_column = "output" 23 | self.history_column = None 24 | 25 | 26 | @dataclass 27 | class ModelArguments: 28 | """ 29 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 30 | """ 31 | model_name_or_path: Optional[str] = field( 32 | default=CHATGLM_REPO_NAME, 33 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} 34 | ) 35 | config_name: Optional[str] = field( 36 | default=None, 37 | metadata={"help": "Pretrained config name or path if not the same as model_name."} 38 | ) 39 | tokenizer_name: Optional[str] = field( 40 | default=None, 41 | metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."} 42 | ) 43 | cache_dir: Optional[str] = field( 44 | default=None, 45 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} 46 | ) 47 | use_fast_tokenizer: Optional[bool] = field( 48 | default=True, 49 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} 50 | ) 51 | model_revision: Optional[str] = field( 52 | default=CHATGLM_LASTEST_HASH, 53 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 54 | ) 55 | use_auth_token: Optional[bool] = field( 56 | default=False, 57 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."} 58 | ) 59 | quantization_bit: Optional[int] = field( 60 | default=None, 61 | metadata={"help": "The number of bits to quantize the model."} 62 | ) 63 | checkpoint_dir: Optional[str] = field( 64 | default=None, 65 | metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."} 66 | ) 67 | reward_model: Optional[str] = field( 68 | default=None, 69 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."} 70 | ) 71 | 72 | def __post_init__(self): 73 | if self.checkpoint_dir is not None: # support merging lora weights 74 | self.checkpoint_dir = [cd.strip() for cd in 
self.checkpoint_dir.split(",")] 75 | 76 | 77 | @dataclass 78 | class DataTrainingArguments: 79 | """ 80 | Arguments pertaining to what data we are going to input our model for training and evaluation. 81 | """ 82 | dataset: Optional[str] = field( 83 | default="alpaca_zh", 84 | metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."} 85 | ) 86 | dataset_dir: Optional[str] = field( 87 | default="data", 88 | metadata={"help": "The name of the folder containing datasets."} 89 | ) 90 | split: Optional[str] = field( 91 | default="train", 92 | metadata={"help": "Which dataset split to use for training and evaluation."} 93 | ) 94 | overwrite_cache: Optional[bool] = field( 95 | default=False, 96 | metadata={"help": "Overwrite the cached training and evaluation sets."} 97 | ) 98 | preprocessing_num_workers: Optional[int] = field( 99 | default=None, 100 | metadata={"help": "The number of processes to use for the preprocessing."} 101 | ) 102 | max_source_length: Optional[int] = field( 103 | default=512, 104 | metadata={"help": "The maximum total input sequence length after tokenization."} 105 | ) 106 | max_target_length: Optional[int] = field( 107 | default=512, 108 | metadata={"help": "The maximum total output sequence length after tokenization."} 109 | ) 110 | max_samples: Optional[int] = field( 111 | default=None, 112 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} 113 | ) 114 | num_beams: Optional[int] = field( 115 | default=None, 116 | metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} 117 | ) 118 | ignore_pad_token_for_loss: Optional[bool] = field( 119 | default=True, 120 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} 121 | ) 122 | source_prefix: Optional[str] = field( 123 | default=None, 124 | metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 125 | ) 126 | 127 | def __post_init__(self): # support mixing multiple datasets 128 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] 129 | dataset_info = json.load(open(os.path.join(self.dataset_dir, "dataset_info.json"), "r")) 130 | 131 | self.dataset_list = [] 132 | for name in dataset_names: 133 | if name not in dataset_info: 134 | raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) 135 | 136 | if "hf_hub_url" in dataset_info[name]: 137 | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) 138 | elif "script_url" in dataset_info[name]: 139 | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) 140 | else: 141 | dataset_attr = DatasetAttr( 142 | "file", 143 | file_name=dataset_info[name]["file_name"], 144 | file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None 145 | ) 146 | 147 | if "columns" in dataset_info[name]: 148 | dataset_attr.prompt_column = dataset_info[name]["columns"]["prompt"] 149 | dataset_attr.query_column = dataset_info[name]["columns"]["query"] 150 | dataset_attr.response_column = dataset_info[name]["columns"]["response"] 151 | dataset_attr.history_column = dataset_info[name]["columns"]["history"] 152 | 153 | self.dataset_list.append(dataset_attr) 154 | 155 | 156 | @dataclass 157 | class FinetuningArguments: 158 | """ 159 | Arguments pertaining to which techniques we are going to fine-tuning with. 
160 | """ 161 | finetuning_type: Optional[str] = field( 162 | default="lora", 163 | metadata={"help": "Which fine-tuning method to use."} 164 | ) 165 | num_layer_trainable: Optional[int] = field( 166 | default=3, 167 | metadata={"help": "Number of trainable layers for Freeze fine-tuning."} 168 | ) 169 | name_module_trainable: Optional[str] = field( 170 | default="mlp", 171 | metadata={"help": "Name of trainable modules for Freeze fine-tuning."} 172 | ) 173 | pre_seq_len: Optional[int] = field( 174 | default=16, 175 | metadata={"help": "Number of prefix tokens to use for P-tuning V2."} 176 | ) 177 | prefix_projection: Optional[bool] = field( 178 | default=False, 179 | metadata={"help": "Whether to add a project layer for the prefix in P-tuning V2 or not."} 180 | ) 181 | lora_rank: Optional[int] = field( 182 | default=8, 183 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} 184 | ) 185 | lora_alpha: Optional[float] = field( 186 | default=32.0, 187 | metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"} 188 | ) 189 | lora_dropout: Optional[float] = field( 190 | default=0.1, 191 | metadata={"help": "Dropout rate for the LoRA fine-tuning."} 192 | ) 193 | lora_target: Optional[str] = field( 194 | default="query_key_value", 195 | metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."} 196 | ) 197 | resume_lora_training: Optional[bool] = field( 198 | default=True, 199 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} 200 | ) 201 | plot_loss: Optional[bool] = field( 202 | default=False, 203 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."} 204 | ) 205 | 206 | def __post_init__(self): 207 | self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA 208 | 209 | if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 210 | trainable_layer_ids = [27-k for k in range(self.num_layer_trainable)] 211 | else: # fine-tuning the first n layers if num_layer_trainable < 0 212 | trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] 213 | if self.name_module_trainable == "mlp": 214 | self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids] 215 | elif self.name_module_trainable == "qkv": 216 | self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids] 217 | 218 | if self.finetuning_type not in ["none", "freeze", "p_tuning", "lora"]: 219 | raise NotImplementedError("Invalid fine-tuning method.") 220 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/other.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import logging 6 | from typing import Dict, List, Optional 7 | 8 | from transformers import Seq2SeqTrainingArguments 9 | from transformers.trainer import TRAINER_STATE_NAME 10 | from transformers.modeling_utils import PreTrainedModel 11 | 12 | from peft.utils.other import WEIGHTS_NAME 13 | 14 | 15 | IGNORE_INDEX = -100 16 | VALUE_HEAD_FILE_NAME = "value_head.bin" 17 | FINETUNING_ARGS_NAME = "finetuning_args.bin" 18 | PREDICTION_FILE_NAME = "generated_predictions.txt" 19 | 20 | 21 | logger = logging.getLogger(__name__) # setup logging 22 | logger.setLevel(logging.INFO) 23 | 
logging.basicConfig( 24 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 25 | datefmt="%m/%d/%Y %H:%M:%S", 26 | handlers=[logging.StreamHandler(sys.stdout)], 27 | ) 28 | 29 | 30 | class AverageMeter: 31 | r""" 32 | Computes and stores the average and current value. 33 | """ 34 | def __init__(self): 35 | self.reset() 36 | 37 | def reset(self): 38 | self.val = 0 39 | self.avg = 0 40 | self.sum = 0 41 | self.count = 0 42 | 43 | def update(self, val, n=1): 44 | self.val = val 45 | self.sum += val * n 46 | self.count += n 47 | self.avg = self.sum / self.count 48 | 49 | 50 | # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32 51 | # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35 52 | def prepare_model_for_training( 53 | model: PreTrainedModel, 54 | output_embedding_layer_name: Optional[str] = "lm_head", 55 | use_gradient_checkpointing: Optional[bool] = True, 56 | layer_norm_names: List[str] = ["layernorm"] # for chatglm setting 57 | ) -> PreTrainedModel: 58 | 59 | for name, param in model.named_parameters(): 60 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 61 | param.data = param.data.to(torch.float32) 62 | 63 | if use_gradient_checkpointing: 64 | model.enable_input_require_grads() 65 | model.gradient_checkpointing_enable() 66 | model.config.use_cache = False # turn off when gradient checkpointing is enabled 67 | 68 | if hasattr(model, output_embedding_layer_name): 69 | output_embedding_layer = getattr(model, output_embedding_layer_name) 70 | input_dtype = output_embedding_layer.weight.dtype 71 | 72 | class CastOutputToFloat(torch.nn.Sequential): 73 | 74 | def forward(self, x): 75 | return super().forward(x.to(input_dtype)).to(torch.float32) 76 | 77 | setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) 78 | 79 | return model 80 | 81 | 82 | def print_trainable_params(model: torch.nn.Module) -> None: 83 | trainable_params, all_param = 0, 0 84 | for param in model.parameters(): 85 | num_params = param.numel() 86 | # if using DS Zero 3 and the weights are initialized empty 87 | if num_params == 0 and hasattr(param, "ds_numel"): 88 | num_params = param.ds_numel 89 | all_param += num_params 90 | if param.requires_grad: 91 | trainable_params += num_params 92 | print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( 93 | trainable_params, all_param, 100 * trainable_params / all_param)) 94 | 95 | 96 | def filter_model_params(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # filter out freezed parameters 97 | state_dict = model.state_dict() 98 | filtered_state_dict = {} 99 | for k, v in model.named_parameters(): 100 | if v.requires_grad: 101 | filtered_state_dict[k] = state_dict[k] 102 | return filtered_state_dict 103 | 104 | 105 | def save_trainable_params(save_directory: os.PathLike, model: torch.nn.Module) -> None: 106 | if os.path.isfile(save_directory): 107 | raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file.") 108 | os.makedirs(save_directory, exist_ok=True) 109 | filtered_state_dict = filter_model_params(model) 110 | torch.save(filtered_state_dict, os.path.join(save_directory, WEIGHTS_NAME)) 111 | 112 | 113 | def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None: 114 | weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) 115 | if not 
os.path.exists(weights_file): 116 | raise ValueError(f"Provided path ({checkpoint_dir}) does not contain the pretrained weights.") 117 | model_state_dict = torch.load(weights_file) 118 | model.load_state_dict(model_state_dict, strict=False) # skip missing keys 119 | 120 | 121 | def save_valuehead_params(save_directory: os.PathLike, v_head: torch.nn.Module) -> None: 122 | if os.path.isfile(save_directory): 123 | raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file.") 124 | os.makedirs(save_directory, exist_ok=True) 125 | torch.save(v_head.state_dict(), os.path.join(save_directory, VALUE_HEAD_FILE_NAME)) 126 | 127 | 128 | def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None: 129 | valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME) 130 | if not os.path.exists(valuehead_file): 131 | raise ValueError(f"Provided path ({checkpoint_dir}) does not contain the valuehead weights.") 132 | valuehead_state_dict = torch.load(valuehead_file) 133 | model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"]) 134 | model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"]) 135 | model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"])) 136 | model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])) 137 | 138 | 139 | def plot_loss(training_args: Seq2SeqTrainingArguments) -> None: 140 | import matplotlib.pyplot as plt 141 | FIGURE_NAME = "trainer_state.png" 142 | data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r")) 143 | train_steps, train_losses = [], [] 144 | for i in range(len(data["log_history"]) - 1): 145 | train_steps.append(data["log_history"][i]["step"]) 146 | train_losses.append(data["log_history"][i]["loss"]) 147 | plt.figure() 148 | plt.plot(train_steps, train_losses) 149 | plt.title("training loss of {}".format(training_args.output_dir)) 150 | plt.xlabel("step") 151 | plt.ylabel("training loss") 152 | plt.savefig(os.path.join(training_args.output_dir, FIGURE_NAME), format="png", transparent=True, dpi=300) 153 | print("Figure saved: {}".format(os.path.join(training_args.output_dir, FIGURE_NAME))) 154 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/pairwise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | from typing import Dict, Optional, Sequence 6 | 7 | from transformers import Trainer, DataCollatorWithPadding 8 | from transformers.trainer import TRAINING_ARGS_NAME 9 | from transformers.tokenization_utils import PreTrainedTokenizer 10 | 11 | from .config import FinetuningArguments 12 | 13 | from .other import ( 14 | save_trainable_params, 15 | save_valuehead_params, 16 | FINETUNING_ARGS_NAME 17 | ) 18 | 19 | 20 | logger = logging.getLogger(__name__) # setup logging 21 | logger.setLevel(logging.INFO) 22 | logging.basicConfig( 23 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 24 | datefmt="%m/%d/%Y %H:%M:%S", 25 | handlers=[logging.StreamHandler(sys.stdout)], 26 | ) 27 | 28 | 29 | class PairwiseDataCollatorForChatGLM(DataCollatorWithPadding): 30 | r""" 31 | Data collator for ChatGLM. It is capable of dynamically padding for batched data. 
32 | 33 | Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/65512697dc67779a6e53c267488aba0ec4d7c02a/train.py#L156 34 | """ 35 | def __init__( 36 | self, 37 | tokenizer: PreTrainedTokenizer, 38 | inference_mode: bool = False 39 | ): 40 | super().__init__(tokenizer, padding=True) 41 | self.inference_mode = inference_mode 42 | 43 | def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]: 44 | r""" 45 | Pads batched data to the longest sequence in the batch. We adopt right-padding for pairwise data. 46 | 47 | ChatGLM is able to generate attentions masks and position ids by itself. 48 | """ 49 | if self.inference_mode: 50 | raise NotImplementedError 51 | accept_ids, reject_ids = [[torch.tensor(feature[key]) for feature in features] for key in ("accept_ids", "reject_ids")] 52 | accept_ids = torch.nn.utils.rnn.pad_sequence(accept_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 53 | reject_ids = torch.nn.utils.rnn.pad_sequence(reject_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 54 | features = {"accept_ids": accept_ids, "reject_ids": reject_ids} 55 | return features 56 | 57 | class PairwiseTrainerForChatGLM(Trainer): 58 | r""" 59 | Inherits Trainer to compute pairwise loss. 60 | """ 61 | 62 | def __init__(self, finetuning_args: FinetuningArguments, *args, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | self.finetuning_args = finetuning_args 65 | 66 | def compute_loss(self, model, inputs, return_outputs=False): 67 | r""" 68 | Computes pairwise loss. 69 | 70 | There are two different implmentations: 71 | [1] https://github.com/lvwerra/trl/blob/52fecee8839ad826ad1e6c83a95c99a4116e98d2/examples/summarization/scripts/reward_summarization.py#L181 72 | [2] https://github.com/microsoft/DeepSpeedExamples/blob/f4ad1d5721630185a9088565f9201929a8b1ffdf/applications/DeepSpeed-Chat/training/utils/model/reward_model.py#L37 73 | Now we adopt the first implementation. We will consider adopting the second implementation later. 74 | """ 75 | _, _, r_accept = model(input_ids=inputs["accept_ids"]) 76 | _, _, r_reject = model(input_ids=inputs["reject_ids"]) 77 | s_accept = r_accept.transpose(0, 1)[(inputs["accept_ids"] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)] 78 | s_reject = r_reject.transpose(0, 1)[(inputs["reject_ids"] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)] 79 | loss = -torch.log(torch.sigmoid(s_accept - s_reject)).mean() 80 | if return_outputs: 81 | return loss, {"r_accept": r_accept, "r_reject": r_reject} 82 | return loss 83 | 84 | def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: 85 | r""" 86 | Saves trainable parameters as model checkpoints. Use `self.model.pretrained_model` to refer to the backbone model. 87 | 88 | Override to inject custom behavior. 
89 | """ 90 | output_dir = output_dir if output_dir is not None else self.args.output_dir 91 | os.makedirs(output_dir, exist_ok=True) 92 | logger.info(f"Saving model checkpoint to {output_dir}") 93 | if hasattr(self.model.pretrained_model, "peft_config"): # LoRA 94 | self.model.pretrained_model.save_pretrained(output_dir) # only save peft weights with the built-in method 95 | else: # Freeze and P-Tuning 96 | save_trainable_params(output_dir, self.model.pretrained_model) 97 | if hasattr(self.model, "v_head"): 98 | save_valuehead_params(output_dir, self.model.v_head) # save valuehead weights 99 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 100 | torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME)) 101 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import logging 6 | from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple 7 | 8 | from transformers import DataCollatorWithPadding, Seq2SeqTrainingArguments 9 | from transformers.trainer import TRAINING_ARGS_NAME, TRAINER_STATE_NAME 10 | from transformers.tokenization_utils import PreTrainedTokenizer 11 | 12 | from trl import PPOTrainer, AutoModelForCausalLMWithValueHead 13 | from trl.core import LengthSampler 14 | from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits 15 | 16 | from .config import FinetuningArguments 17 | 18 | from .other import ( 19 | AverageMeter, 20 | save_trainable_params, 21 | save_valuehead_params, 22 | FINETUNING_ARGS_NAME 23 | ) 24 | 25 | 26 | logger = logging.getLogger(__name__) # setup logging 27 | logger.setLevel(logging.INFO) 28 | logging.basicConfig( 29 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 30 | datefmt="%m/%d/%Y %H:%M:%S", 31 | handlers=[logging.StreamHandler(sys.stdout)], 32 | ) 33 | 34 | 35 | def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: 36 | if target == "reward": 37 | valuehead_state_dict = model.v_head.state_dict() 38 | 39 | setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"]) 40 | setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"]) 41 | 42 | model.pretrained_model.set_adapter(target) 43 | model.v_head.load_state_dict({ 44 | "summary.weight": getattr(model, "{}_head_weight".format(target)), 45 | "summary.bias": getattr(model, "{}_head_bias".format(target)) 46 | }) 47 | 48 | 49 | @torch.no_grad() 50 | def compute_rewards( 51 | input_ids: torch.Tensor, # (batch size x seq len) with format `X [gMASK] [BOS] Y [EOS] [PAD] ... 
[PAD]` 52 | model: AutoModelForCausalLMWithValueHead, 53 | tokenizer: PreTrainedTokenizer 54 | ) -> torch.Tensor: 55 | 56 | replace_model(model, target="reward") 57 | 58 | _, _, values = model(input_ids=input_ids) 59 | values = values.transpose(0, 1) 60 | 61 | rewards = [] 62 | for i in range(input_ids.size(0)): 63 | eos_idx = (input_ids[i] == tokenizer.eos_token_id).nonzero() # Note: checking with eos_token is unsafe 64 | if len(eos_idx): 65 | eos_idx = eos_idx[0].item() 66 | else: 67 | eos_idx = input_ids.size(1) - 1 68 | rewards.append(values[i][eos_idx]) 69 | rewards = torch.stack(rewards, dim=0) 70 | 71 | replace_model(model, target="default") 72 | 73 | return rewards 74 | 75 | 76 | def cast_layernorm_dtype( 77 | model: AutoModelForCausalLMWithValueHead, 78 | layer_norm_names: List[str] = ["layernorm"], # for chatglm setting 79 | layer_norm_params: Optional[Dict[str, torch.Tensor]] = None 80 | ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: 81 | 82 | layer_norm_state_dict = {} 83 | 84 | for name, param in model.named_parameters(): 85 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 86 | if layer_norm_params is not None: 87 | param.data = layer_norm_params[name] # restore float32 weights 88 | else: 89 | layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability 90 | param.data = param.data.to(torch.float16) 91 | 92 | return model, layer_norm_state_dict 93 | 94 | 95 | class PPODataCollatorForChatGLM(DataCollatorWithPadding): 96 | r""" 97 | Data collator for ChatGLM. It is capable of dynamically padding for batched data. 98 | 99 | Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/65512697dc67779a6e53c267488aba0ec4d7c02a/train.py#L156 100 | """ 101 | def __init__( 102 | self, 103 | tokenizer: PreTrainedTokenizer, 104 | min_input_length: int, 105 | max_input_length: int, 106 | inference_mode: bool = False, 107 | ): 108 | super().__init__(tokenizer, padding=True) 109 | self.inference_mode = inference_mode 110 | if min_input_length < max_input_length: 111 | self.input_size = LengthSampler(min_input_length, max_input_length) 112 | else: 113 | self.input_size = lambda: max_input_length # always use max_input_length 114 | 115 | def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]: 116 | r""" 117 | Pads batched data to the longest sequence in the batch. We adopt left-padding for ppo data. 118 | 119 | Equips with a length sampler to generate sequences with variable lengths. 120 | 121 | ChatGLM is able to generate attentions masks and position ids by itself. 122 | """ 123 | if self.inference_mode: 124 | raise NotImplementedError 125 | input_ids = [torch.tensor(feature["input_ids"][:self.input_size()]).flip(0) for feature in features] 126 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 127 | features = {"input_ids": input_ids.flip(-1)} 128 | return features 129 | 130 | class PPOTrainerForChatGLM(PPOTrainer): 131 | r""" 132 | Inherits PPOTrainer. 
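Main customizations for ChatGLM: `generate` temporarily casts LayerNorm weights to fp16 around decoding and restores them afterwards, `batched_forward_pass` feeds only `input_ids` (ChatGLM builds its own attention masks and position ids) and masks everything outside the [BOS] ... [EOS] response span, and `save_model` writes checkpoints in the same format as the other trainers in this project.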
133 | """ 134 | 135 | def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, *args, **kwargs): 136 | super().__init__(*args, **kwargs) 137 | self.steps = 0 138 | self.loss_meter = AverageMeter() 139 | self.reward_meter = AverageMeter() 140 | self.trainer_state = {"log_history": []} 141 | self.training_args = training_args 142 | self.finetuning_args = finetuning_args 143 | 144 | def generate( 145 | self, 146 | query_tensor: torch.Tensor, # (batch size x seq len) 147 | length_sampler: Callable = None, 148 | return_prompt: bool = True, 149 | **generation_kwargs, 150 | ) -> torch.Tensor: 151 | r""" 152 | Generate response with the model given the query tensor. 153 | 154 | Inspired by: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/trl/trainer/ppo_trainer.py#L387 155 | """ 156 | 157 | self.model, layer_norm_params = cast_layernorm_dtype(self.model) 158 | 159 | if length_sampler is not None: 160 | generation_kwargs["max_new_tokens"] = length_sampler() 161 | 162 | response = self.accelerator.unwrap_model(self.model).generate( 163 | input_ids=query_tensor, **generation_kwargs 164 | ) 165 | 166 | # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop 167 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273 168 | if self.model.pretrained_model.generation_config._from_model_config: 169 | self.model.pretrained_model.generation_config._from_model_config = False 170 | 171 | self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params) 172 | 173 | if not return_prompt and not self.is_encoder_decoder: 174 | return response[:, query_tensor.size(1):] 175 | return response 176 | 177 | def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): 178 | input_ids = [] 179 | for query, response in zip(queries, responses): # query is left-padded, response is right-padded 180 | start = (query != self.tokenizer.pad_token_id).nonzero()[0].item() 181 | input_ids.append(torch.cat((query[start:], response, query[:start]))) # change to right-padding 182 | 183 | input_data = self.data_collator([{"input_ids": ids} for ids in input_ids]).to(self.current_device) 184 | input_data.pop("labels", None) # we don't want to compute LM losses 185 | 186 | return input_data 187 | 188 | @PPODecorators.empty_cuda_cache() 189 | def batched_forward_pass( 190 | self, 191 | model: AutoModelForCausalLMWithValueHead, 192 | queries: torch.Tensor, 193 | responses: torch.Tensor, 194 | model_inputs: dict, 195 | ): 196 | r""" 197 | Calculate model outputs in multiple batches. 198 | 199 | Override to inject custom behavior. 
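The batch is processed in chunks of `config.mini_batch_size`; log-probabilities are computed from the shifted logits, the response mask covers the tokens from [BOS] up to (but not including) [EOS], and the returned logits, values and masks are truncated by one position to stay aligned with the log-probabilities.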
200 | """ 201 | bs = len(queries) 202 | fbs = self.config.mini_batch_size 203 | all_logprobs = [] 204 | all_logits = [] 205 | all_masks = [] 206 | all_values = [] 207 | 208 | for i in range(int(bs / fbs)): 209 | input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} 210 | 211 | input_ids = input_kwargs["input_ids"] 212 | logits, _, values = model(input_ids=input_ids) # chatglm only needs input_ids 213 | logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) 214 | 215 | values = values.transpose(0, 1) 216 | masks = torch.zeros_like(input_ids) 217 | 218 | for j in range(fbs): 219 | start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item() 220 | end = (input_ids[j] == self.tokenizer.eos_token_id).nonzero() 221 | if len(end): 222 | end = end[0].item() 223 | else: 224 | end = masks.size(1) 225 | masks[j][start:end] = 1 226 | 227 | all_logits.append(logits) 228 | all_values.append(values) 229 | all_logprobs.append(logprobs) 230 | all_masks.append(masks) 231 | 232 | return ( 233 | torch.cat(all_logprobs), 234 | torch.cat(all_logits)[:, :-1], 235 | torch.cat(all_values)[:, :-1], 236 | torch.cat(all_masks)[:, :-1], 237 | ) 238 | 239 | def update_stats(self, stats: Dict[str, Any], batch: Dict[str, torch.Tensor], rewards: torch.Tensor) -> None: 240 | self.steps += 1 241 | self.loss_meter.update(stats["ppo/loss/total"]) 242 | self.reward_meter.update(rewards.sum().item(), n=rewards.size(0)) 243 | if self.steps % self.training_args.logging_steps == 0: 244 | print("{{'loss': {:.4f}, 'reward': {:.4f}, 'learning_rate': {:}}}".format( 245 | self.loss_meter.avg, self.reward_meter.avg, stats["ppo/learning_rate"] 246 | )) 247 | self.trainer_state["log_history"].append({ 248 | "loss": self.loss_meter.avg, 249 | "reward": self.reward_meter.avg, 250 | "step": self.steps 251 | }) 252 | self.loss_meter.reset() 253 | self.reward_meter.reset() 254 | 255 | def save_state(self, output_dir: Optional[str] = None) -> None: 256 | r""" 257 | Saves trainer state. 258 | """ 259 | output_dir = output_dir if output_dir is not None else self.training_args.output_dir 260 | os.makedirs(output_dir, exist_ok=True) 261 | json.dump(self.trainer_state, open(os.path.join(output_dir, TRAINER_STATE_NAME), "w", encoding="utf-8", newline="\n")) 262 | 263 | def save_model(self, output_dir: Optional[str] = None) -> None: 264 | r""" 265 | Saves trainable parameters as model checkpoints. We use `self.model.pretrained_model` to refer to the backbone model. 266 | 267 | Override to inject custom behavior. 
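Only the PEFT or otherwise trainable backbone weights and the value head are written to `output_dir`, together with the training and finetuning arguments; the logged loss/reward history is persisted separately by `save_state`.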
268 | """ 269 | output_dir = output_dir if output_dir is not None else self.training_args.output_dir 270 | os.makedirs(output_dir, exist_ok=True) 271 | logger.info(f"Saving model checkpoint to {output_dir}") 272 | if hasattr(self.model.pretrained_model, "peft_config"): # LoRA 273 | self.model.pretrained_model.save_pretrained(output_dir) # only save peft weights with the built-in method 274 | else: # Freeze and P-Tuning 275 | save_trainable_params(output_dir, self.model.pretrained_model) 276 | if hasattr(self.model, "v_head"): 277 | save_valuehead_params(output_dir, self.model.v_head) # save valuehead weights 278 | torch.save(self.training_args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 279 | torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME)) 280 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/seq2seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import logging 6 | import numpy as np 7 | from dataclasses import dataclass 8 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union 9 | 10 | from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq 11 | from transformers.trainer import PredictionOutput, TRAINING_ARGS_NAME 12 | from transformers.deepspeed import is_deepspeed_zero3_enabled 13 | from transformers.modeling_utils import PreTrainedModel 14 | from transformers.tokenization_utils import PreTrainedTokenizer 15 | 16 | import jieba 17 | from rouge_chinese import Rouge 18 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 19 | 20 | from .config import FinetuningArguments 21 | 22 | from .other import ( 23 | save_trainable_params, 24 | IGNORE_INDEX, 25 | FINETUNING_ARGS_NAME, 26 | PREDICTION_FILE_NAME 27 | ) 28 | 29 | 30 | logger = logging.getLogger(__name__) # setup logging 31 | logger.setLevel(logging.INFO) 32 | logging.basicConfig( 33 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 34 | datefmt="%m/%d/%Y %H:%M:%S", 35 | handlers=[logging.StreamHandler(sys.stdout)], 36 | ) 37 | 38 | 39 | # Note: The ChatGLM tokenizer assigns False on token to be attended in attention mask. In general settings, it should be True. 40 | # Refer to: https://huggingface.co/THUDM/chatglm-6b/blob/6650ae3a53c28fc176d06762ca80b05d5ab3792b/tokenization_chatglm.py#L401 41 | class Seq2SeqDataCollatorForChatGLM(DataCollatorForSeq2Seq): 42 | r""" 43 | Data collator for ChatGLM. It is capable of dynamically padding for batched data. 44 | 45 | Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/65512697dc67779a6e53c267488aba0ec4d7c02a/train.py#L156 46 | """ 47 | def __init__( 48 | self, 49 | tokenizer: PreTrainedTokenizer, 50 | model: PreTrainedModel, 51 | ignore_pad_token_for_loss: bool, 52 | inference_mode: bool = False 53 | ): 54 | label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id 55 | super().__init__(tokenizer, model=model, label_pad_token_id=label_pad_token_id, padding=True) 56 | self.label_pad_token_id = label_pad_token_id 57 | self.inference_mode = inference_mode 58 | 59 | def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]: 60 | r""" 61 | Pads batched data to the longest sequence in the batch. 62 | 63 | ChatGLM is able to generate attentions masks and position ids by itself. 
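During training, `input_ids` and `labels` are right-padded here with `pad_token_id` and `label_pad_token_id` respectively; in inference mode the call is delegated to the parent collator, which uses left-padding for generation.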
64 | """ 65 | if self.inference_mode: # evaluation set adopts left-padding while training set adopts right-padding 66 | return super().__call__(features) 67 | input_ids, labels = [[torch.tensor(feature[key]) for feature in features] for key in ("input_ids", "labels")] 68 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 69 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=self.label_pad_token_id) 70 | features = {"input_ids": input_ids, "labels": labels} 71 | return features 72 | 73 | 74 | @dataclass 75 | class ComputeMetrics: 76 | r""" 77 | Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForChatGLM. 78 | 79 | Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307 80 | """ 81 | 82 | tokenizer: PreTrainedTokenizer 83 | 84 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 85 | r""" 86 | Uses the model predictions to compute metrics. 87 | """ 88 | preds, labels = eval_preds 89 | if isinstance(preds, tuple): 90 | preds = preds[0] 91 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 92 | # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True. 93 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) 94 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 95 | 96 | score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 97 | for pred, label in zip(decoded_preds, decoded_labels): 98 | hypothesis = list(jieba.cut(pred)) 99 | reference = list(jieba.cut(label)) 100 | 101 | if len(" ".join(hypothesis).split()) == 0: 102 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} 103 | else: 104 | rouge = Rouge() 105 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) 106 | result = scores[0] 107 | 108 | for k, v in result.items(): 109 | score_dict[k].append(round(v["f"] * 100, 4)) 110 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 111 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 112 | 113 | return {k: float(np.mean(v)) for k, v in score_dict.items()} 114 | 115 | 116 | class Seq2SeqTrainerForChatGLM(Seq2SeqTrainer): 117 | r""" 118 | Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. 119 | """ 120 | 121 | def __init__(self, finetuning_args: FinetuningArguments, *args, **kwargs): 122 | super().__init__(*args, **kwargs) 123 | self.finetuning_args = finetuning_args 124 | 125 | def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: 126 | r""" 127 | Saves trainable parameters as model checkpoints. 128 | 129 | Override to inject custom behavior. 
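As in the other trainers, only the trainable weights are written: the LoRA adapter via `save_pretrained` when a `peft_config` is present, or the Freeze / P-Tuning parameters via `save_trainable_params`, together with the training and finetuning arguments.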
130 | """ 131 | output_dir = output_dir if output_dir is not None else self.args.output_dir 132 | os.makedirs(output_dir, exist_ok=True) 133 | logger.info(f"Saving model checkpoint to {output_dir}") 134 | if hasattr(self.model, "peft_config"): # LoRA 135 | self.model.save_pretrained(output_dir) # only save peft weights with the built-in method 136 | else: 137 | save_trainable_params(output_dir, self.model) # Freeze and P-Tuning 138 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 139 | torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME)) 140 | 141 | def prediction_step( 142 | self, 143 | model: torch.nn.Module, 144 | inputs: Dict[str, Union[torch.Tensor, Any]], 145 | prediction_loss_only: bool, 146 | ignore_keys: Optional[List[str]] = None 147 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 148 | r""" 149 | Performs an evaluation step on `model` using `inputs` for ChatGLM. 150 | 151 | Override to inject custom behavior. It is not directly used by external scripts. 152 | """ 153 | # Override to inject custom bevavior. 154 | if not self.args.predict_with_generate or prediction_loss_only: 155 | return super().prediction_step( 156 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 157 | ) 158 | 159 | has_labels = "labels" in inputs 160 | inputs = self._prepare_inputs(inputs) 161 | 162 | gen_kwargs = self._gen_kwargs.copy() 163 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 164 | gen_kwargs["max_length"] = self.model.config.max_length 165 | gen_kwargs["num_beams"] = gen_kwargs["num_beams"] \ 166 | if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 167 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 168 | gen_kwargs["synced_gpus"] = gen_kwargs["synced_gpus"] \ 169 | if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 170 | 171 | if "attention_mask" in inputs: 172 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 173 | if "position_ids" in inputs: 174 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 175 | if "global_attention_mask" in inputs: 176 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 177 | 178 | # prepare generation inputs 179 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 180 | generation_inputs = inputs[self.model.encoder.main_input_name] 181 | else: 182 | generation_inputs = inputs[self.model.main_input_name] 183 | 184 | gen_kwargs["input_ids"] = generation_inputs 185 | generated_tokens = self.model.generate(**gen_kwargs) 186 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] # important for ChatGLM 187 | 188 | # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop 189 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273 190 | if self.model.generation_config._from_model_config: 191 | self.model.generation_config._from_model_config = False 192 | 193 | # Retrieves GenerationConfig from model.generation_config 194 | gen_config = self.model.generation_config 195 | # in case the batch is shorter than max length, the output should be padded 196 | if generated_tokens.shape[-1] < gen_config.max_length: 197 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) 198 | elif 
gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: 199 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) 200 | 201 | loss = None 202 | 203 | if self.args.prediction_loss_only: 204 | return loss, None, None 205 | 206 | if has_labels: 207 | labels = inputs["labels"] 208 | if labels.shape[-1] < gen_config.max_length: 209 | labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) 210 | elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: 211 | labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) 212 | else: 213 | labels = None 214 | 215 | return loss, generated_tokens, labels 216 | 217 | def save_predictions( 218 | self, 219 | predict_results: PredictionOutput, 220 | tokenizer: PreTrainedTokenizer 221 | ) -> None: 222 | r""" 223 | Saves model predictions to `output_dir`. 224 | 225 | A custom behavior that not contained in Seq2SeqTrainer. 226 | """ 227 | if self.is_world_process_zero(): 228 | if self.args.predict_with_generate: 229 | predictions = tokenizer.batch_decode(predict_results.predictions, skip_special_tokens=True) 230 | predictions = [pred.strip() for pred in predictions] 231 | labels = tokenizer.batch_decode(predict_results.label_ids, skip_special_tokens=True) 232 | labels = [label.strip() for label in labels] 233 | output_prediction_file = os.path.join(self.args.output_dir, PREDICTION_FILE_NAME) 234 | logger.info(f"Saving prediction results to {output_prediction_file}") 235 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 236 | res = [] 237 | for pred, label in zip(predictions, labels): 238 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) 239 | writer.write("\n".join(res)) 240 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/web_demo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implement user interface in browser for ChatGLM fine-tuned with PEFT. 3 | # This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/web_demo.py 4 | 5 | 6 | import gradio as gr 7 | import mdtex2html 8 | 9 | from utils import ModelArguments, load_pretrained 10 | from transformers import HfArgumentParser 11 | 12 | 13 | parser = HfArgumentParser(ModelArguments) 14 | model_args, = parser.parse_args_into_dataclasses() 15 | model, tokenizer = load_pretrained(model_args) 16 | model = model.cuda() 17 | model.eval() 18 | 19 | 20 | """Override Chatbot.postprocess""" 21 | 22 | def postprocess(self, y): 23 | if y is None: 24 | return [] 25 | for i, (message, response) in enumerate(y): 26 | y[i] = ( 27 | None if message is None else mdtex2html.convert((message)), 28 | None if response is None else mdtex2html.convert(response), 29 | ) 30 | return y 31 | 32 | 33 | gr.Chatbot.postprocess = postprocess 34 | 35 | 36 | def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT 37 | lines = text.split("\n") 38 | lines = [line for line in lines if line != ""] 39 | count = 0 40 | for i, line in enumerate(lines): 41 | if "```" in line: 42 | count += 1 43 | items = line.split('`') 44 | if count % 2 == 1: 45 | lines[i] = f'
<pre><code class="language-{items[-1]}">'
 46 |             else:
 47 |                 lines[i] = f'<br></code></pre>
' 48 |         else: 49 |             if i > 0: 50 |                 if count % 2 == 1: 51 |                     line = line.replace("`", "\`") 52 |                     line = line.replace("<", "&lt;") 53 |                     line = line.replace(">", "&gt;") 54 |                     line = line.replace(" ", "&nbsp;") 55 |                     line = line.replace("*", "&ast;") 56 |                     line = line.replace("_", "&lowbar;") 57 |                     line = line.replace("-", "&#45;") 58 |                     line = line.replace(".", "&#46;") 59 |                     line = line.replace("!", "&#33;") 60 |                     line = line.replace("(", "&#40;") 61 |                     line = line.replace(")", "&#41;") 62 |                     line = line.replace("$", "&#36;") 63 |                 lines[i] = "<br>
"+line 64 | text = "".join(lines) 65 | return text 66 | 67 | 68 | def predict(input, chatbot, max_length, top_p, temperature, history): 69 | chatbot.append((parse_text(input), "")) 70 | for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p, 71 | temperature=temperature): 72 | chatbot[-1] = (parse_text(input), parse_text(response)) 73 | 74 | yield chatbot, history 75 | 76 | 77 | def reset_user_input(): 78 | return gr.update(value='') 79 | 80 | 81 | def reset_state(): 82 | return [], [] 83 | 84 | 85 | with gr.Blocks() as demo: 86 | gr.HTML("""

<h1 align="center">MedQA-ChatGLM</h1>

""") 87 | 88 | chatbot = gr.Chatbot() 89 | with gr.Row(): 90 | with gr.Column(scale=4): 91 | with gr.Column(scale=12): 92 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( 93 | container=False) 94 | with gr.Column(min_width=32, scale=1): 95 | submitBtn = gr.Button("Submit", variant="primary") 96 | with gr.Column(scale=1): 97 | emptyBtn = gr.Button("Clear History") 98 | max_length = gr.Slider(0, 4096, value=4096, step=1.0, label="Maximum length", interactive=True) 99 | top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) 100 | temperature = gr.Slider(0, 1, value=0.01, step=0.01, label="Temperature", interactive=True) 101 | 102 | history = gr.State([]) 103 | 104 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], 105 | show_progress=True) 106 | submitBtn.click(reset_user_input, [], [user_input]) 107 | 108 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) 109 | 110 | demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True) 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > **Note** 2 | > 3 | > **欢迎关注我们最新的工作:CareLlama (关怀羊驼),它是一个医疗大语言模型,同时它集合了数十个公开可用的医疗微调数据集和开放可用的医疗大语言模型以促进医疗LLM快速发展:https://github.com/WangRongsheng/CareLlama** 4 | 5 | # MedQA-ChatGLM 1 6 | 7 | ![](./images/model.png) 8 | 9 | 1 使用的数据为[cMedQA2](https://github.com/zhangsheng93/cMedQA2) 10 | 11 | # 资源 12 | 13 | |项目|数据集|底座模型| 14 | |:-|:-|:-| 15 | |[ChatMed](https://github.com/michael-wzhu/ChatMed)|[Consult](https://huggingface.co/michaelwzhu/ChatMed-Consult) 包含50w+在线问诊+ChatGPT回复,TCM中医药诊疗数据集未公开|LLaMA-7B| 16 | |[ChatDoctor](https://github.com/Kent0n-Li/ChatDoctor)|[HealthCareMagic-100k](https://drive.google.com/file/d/1lyfqIwlLSClhgrCutWuEe_IACNq6XNUt/view?usp=sharing) 包含100k+真实患者与医生对话数据集,[icliniq-10k](https://drive.google.com/file/d/1ZKbqgYqWc7DJHs3N9TQYQVPdDQmZaClA/view?usp=sharing) 包含10k+患者与医生对话数据集,[GenMedGPT-5k](https://drive.google.com/file/d/1nDTKZ3wZbZWTkFMBkxlamrzbNz0frugg/view?usp=sharing) 包含5k+由GPT生成的医患对话数据集|LLaMA-7B| 17 | |[Med-ChatGLM](https://github.com/SCIR-HI/Med-ChatGLM)|[Huatuo-data](https://huggingface.co/datasets/wangrongsheng/Huatuo-data) 、[Huatuo-liver-cancer](https://huggingface.co/datasets/wangrongsheng/Huatuo-liver-cancer)|ChatGLM-6B| 18 | |[Huatuo-Llama-Med-Chinese](https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese)|[Huatuo-data](https://huggingface.co/datasets/wangrongsheng/Huatuo-data) 、[Huatuo-liver-cancer](https://huggingface.co/datasets/wangrongsheng/Huatuo-liver-cancer)|LLaMA-7B| 19 | |[DoctorGLM](https://github.com/xionghonglin/DoctorGLM)|[CMD.](https://huggingface.co/datasets/wangrongsheng/CMD-merged) 、[MedDialog](https://huggingface.co/datasets/wangrongsheng/MedDialog-1.1M) 、ChatDoctor项目数据集|ChatGLM-6B| 20 | |[MedicalGPT-zh](https://github.com/MediaBrain-SJTU/MedicalGPT-zh)|数据未开源|ChatGLM-6B| 21 | |[Dr.LLaMA](https://github.com/zguo0525/Dr.LLaMA)||LLaMA| 22 | |[Medical_NLP](https://github.com/FreedomIntelligence/Medical_NLP) 2|-|-| 23 | |[CMCQA](https://github.com/WENGSYX/CMCQA) 3|-|-| 24 | |[QiZhenGPT](https://github.com/CMKRG/QiZhenGPT)|-|-| 25 | |[LLM-Pretrain-FineTune](https://github.com/NLPxiaoxu/LLM-Pretrain-FineTune)|-|-| 26 | |[PMC-LLaMA](https://github.com/chaoyi-wu/PMC-LLaMA)|-|LLaMA-7B| 27 | |[BianQue](https://github.com/scutcyr/BianQue)|-|-| 28 | 
|[medAlpaca](https://github.com/kbressem/medAlpaca)|-|LLaMA-7B| 29 | |[MedicalGPT](https://github.com/shibing624/MedicalGPT)|-|-| 30 | |[LLM-Pretrain-FineTune](https://github.com/X-jun-0130/LLM-Pretrain-FineTune)|-|-| 31 | |[ShenNong-TCM-LLM](https://github.com/michael-wzhu/ShenNong-TCM-LLM)|-|-| 32 | |[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)|-|-| 33 | |[CMLM-ZhongJing](https://github.com/pariskang/CMLM-ZhongJing)|-|-| 34 | |[ZhongJing](https://github.com/SupritYoung/Zhongjing)|-|-| 35 | |[Ming](https://github.com/MediaBrain-SJTU/MING)|-|-| 36 | |[DISC-MedLLM](https://github.com/FudanDISC/DISC-MedLLM)|-|-| 37 | 38 | 39 | - 2 为相关医学的大模型资源,请务必格外关注[FreedomIntelligence](https://github.com/FreedomIntelligence) 40 | - 3 来自中国医学对话问答网站春雨,在男科、耳科、妇产科等45个科室医学对话材料 41 | - https://medical.chat-data.com/ 42 | - https://huggingface.co/datasets/shibing624/medical 43 | 44 | # 使用 45 | 46 | ## 1. 安装环境 47 | ```python 48 | pip install -r requirements.txt 49 | ``` 50 | ## 2. 微调 51 | 52 | ### 2.1 LoRA 53 | ```python 54 | CUDA_VISIBLE_DEVICES=0 python MedQA-ChatGLM/finetune.py \ 55 | --do_train \ 56 | --dataset merged-cMedQA \ 57 | --finetuning_type lora \ 58 | --output_dir ./med-lora \ 59 | --per_device_train_batch_size 32 \ 60 | --gradient_accumulation_steps 256 \ 61 | --lr_scheduler_type cosine \ 62 | --logging_steps 500 \ 63 | --save_steps 1000 \ 64 | --learning_rate 5e-5 \ 65 | --num_train_epochs 10.0 \ 66 | --fp16 67 | ``` 68 | ### 2.2 Freeze微调 69 | ```python 70 | CUDA_VISIBLE_DEVICES=0 python MedQA-ChatGLM/finetune.py \ 71 | --do_train \ 72 | --dataset merged-cMedQA \ 73 | --finetuning_type freeze \ 74 | --output_dir ./med-freeze \ 75 | --per_device_train_batch_size 32 \ 76 | --gradient_accumulation_steps 256 \ 77 | --lr_scheduler_type cosine \ 78 | --logging_steps 500 \ 79 | --save_steps 1000 \ 80 | --learning_rate 5e-5 \ 81 | --num_train_epochs 10.0 \ 82 | --fp16 83 | ``` 84 | 85 | ### 2.3 P-Turning V2 86 | 87 | ```python 88 | CUDA_VISIBLE_DEVICES=1 python MedQA-ChatGLM/finetune.py \ 89 | --do_train --dataset merged-cMedQA \ 90 | --finetuning_type p_tuning \ 91 | --output_dir ./med-p_tuning \ 92 | --per_device_train_batch_size 32 \ 93 | --gradient_accumulation_steps 256 \ 94 | --lr_scheduler_type cosine \ 95 | --logging_steps 500 \ 96 | --save_steps 1000 \ 97 | --learning_rate 5e-5 \ 98 | --num_train_epochs 10.0 \ 99 | --fp16 100 | ``` 101 | 102 | 更多参数信息,可以查看[docs/参数详解.md](https://github.com/WangRongsheng/MedQA-ChatGLM/blob/main/docs/%E5%8F%82%E6%95%B0%E8%AF%A6%E8%A7%A3.md) . 103 | 104 | 多GPU分布式训练: 105 | 106 | ```python 107 | # 配置分布式参数 108 | accelerate config 109 | 110 | # 分布式训练 111 | accelerate launch src/finetune.py \ 112 | --do_train \ 113 | --dataset Huatuo,CMD,MedDialog,guanaco,cognition \ 114 | --finetuning_type lora \ 115 | --output_dir med-lora \ 116 | --per_device_train_batch_size 16 \ 117 | --gradient_accumulation_steps 4 \ 118 | --lr_scheduler_type cosine \ 119 | --logging_steps 10 \ 120 | --save_steps 1000 \ 121 | --learning_rate 5e-5 \ 122 | --num_train_epochs 3.0 \ 123 | --fp16 \ 124 | --ddp_find_unused_parameters False \ # 分布式训练时,LoRA微调需要添加防止报错 125 | --plot_loss 126 | ``` 127 | 128 | ## 3. 推理 129 | 130 | ### 3.1 可视化 131 | ```python 132 | CUDA_VISIBLE_DEVICES=0 python MedQA-ChatGLM/web_demo.py \ 133 | --checkpoint_dir med-lora/ 134 | (med-freez/) 135 | (med-p_tuning/) 136 | ``` 137 | 138 | ### 3.2 命令行 139 | ```python 140 | CUDA_VISIBLE_DEVICES=0 python MedQA-ChatGLM/infer.py \ 141 | --checkpoint_dir med-lora/ 142 | (med-freez/) 143 | (med-p_tuning/) 144 | ``` 145 | 146 | ## 4. 
合并(可选) 147 | 148 | 合并模型: 149 | ```python 150 | CUDA_VISIBLE_DEVICES=0 python MedQA-ChatGLM/export_weights.py \ 151 | --finetuning_weights_path ./med-lora \ 152 | --save_weights_path ./save_lora 153 | ``` 154 | 155 | 加载合并模型: 156 | ```python 157 | CUDA_VISIBLE_DEVICES=0 python MedQA-ChatGLM/load_export_weights.py \ 158 | --save_weights_path ./save_lora 159 | ``` 160 | 161 | # 结果 162 | 163 | |微调方式|模型权重|训练时长|训练轮次| 164 | |:-|:-|:-|:-| 165 | |LoRA|[MedQA-ChatGLM-LoRA](https://huggingface.co/wangrongsheng/MedQA-ChatGLM-LoRA)|28h|10| 166 | |P-Tuning V2|[MedQA-ChatGLM-PTuningV2](https://huggingface.co/wangrongsheng/MedQA-ChatGLM-PTuningV2)|27h|10| 167 | |Freeze|[MedQA-ChatGLM-Freeze](https://huggingface.co/wangrongsheng/MedQA-ChatGLM-Freeze)|28h|10| 168 | 169 |
 170 |   <summary>训练设置</summary> 171 |

* 实验是在Linux系统,A100 (1X, 80GB)上进行的

 172 | </details>
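The fine-tuned weights in the table above can also be loaded directly with Hugging Face `transformers` and `peft`, without going through the scripts in this repository. The snippet below is only a minimal sketch: the base model id (`THUDM/chatglm-6b`), the sample question and the `model.chat` call are assumptions based on the usual ChatGLM-6B + LoRA workflow, not code shipped with this project.

```python
# Minimal sketch (assumptions noted above): load ChatGLM-6B and attach the published LoRA adapter.
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

base_model = "THUDM/chatglm-6b"  # assumed base model
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModel.from_pretrained(base_model, trust_remote_code=True).half().cuda()

# Attach the LoRA weights listed in the table above.
model = PeftModel.from_pretrained(model, "wangrongsheng/MedQA-ChatGLM-LoRA")
model.eval()

response, history = model.chat(tokenizer, "高血压患者饮食需要注意什么?", history=[])
print(response)
```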
173 | 174 | # 免责声明 175 | 176 | 本项目相关资源仅供学术研究之用,严禁用于商业用途。使用涉及第三方代码的部分时,请严格遵循相应的开源协议。模型生成的内容受模型计算、随机性和量化精度损失等因素影响,本项目无法对其准确性作出保证。本项目数据集绝大部分由模型生成,即使符合某些医学事实,也不能被用作实际医学诊断的依据。对于模型输出的任何内容,本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。 177 | 178 | # 参考 179 | 180 | 1. https://github.com/zhangsheng93/cMedQA2 181 | 2. https://github.com/zhangsheng93/cMedQA 182 | 3. https://github.com/hiyouga/ChatGLM-Efficient-Tuning 183 | 4. https://github.com/jackaduma/ChatGLM-LoRA-RLHF-PyTorch 184 | 5. https://github.com/THUDM/ChatGLM-6B 185 | -------------------------------------------------------------------------------- /data/bulid_CMD.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | from tqdm import tqdm 4 | 5 | data = [] 6 | with open('waike.txt', 'r', encoding='utf-8') as f: 7 | reader = csv.DictReader(f) 8 | for row in tqdm(reader): 9 | #print(row['title'], row['ask'], row['answer']) 10 | info = { 11 | "instruction": str(row['title']), 12 | "input": str(row['ask']), 13 | "output": str(row['answer']) 14 | } 15 | data.append(info) 16 | 17 | with open('waike.json', 'w+', encoding='utf-8') as f: 18 | json.dump(data, f) -------------------------------------------------------------------------------- /data/bulid_cMedQA.py: -------------------------------------------------------------------------------- 1 | # 关于编码问题解决:https://blog.csdn.net/AlAuAu/article/details/109478113 2 | import pandas as pd 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | 7 | # 指定文件夹路径 8 | folder_path = "./" 9 | # 读取question.csv文件 10 | questions_df = pd.read_csv(os.path.join(folder_path, "questions.txt")) 11 | # 读取answer.csv文件 12 | answers_df = pd.read_csv(os.path.join(folder_path, "answers.txt")) 13 | 14 | merged = pd.merge(questions_df, answers_df, on='que_id') 15 | #print(len(merged)) 16 | #print(merged.columns) 17 | 18 | # 存储数组 19 | data = [] 20 | 21 | # 遍历merged对象的每一行数据 22 | for index, row in tqdm(merged.iterrows()): 23 | # 获取每一列的数据 24 | q_id = row['que_id'] 25 | q_content = row['content_x'] 26 | 27 | a_id = row['ans_id'] 28 | a_content = row['content_y'] 29 | 30 | #print(q_id) 31 | #print(q_content) 32 | #print(a_id) 33 | #print(a_content) 34 | #break 35 | 36 | # 保存json 37 | info = { 38 | "instruction": str(q_content), 39 | "input": "", 40 | "output": str(a_content) 41 | } 42 | data.append(info) 43 | 44 | with open('cMedQA.json', 'w+', encoding='utf-8') as f: 45 | json.dump(data, f) -------------------------------------------------------------------------------- /data/bulid_cMedQA2.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | 6 | # 指定文件夹路径 7 | folder_path = "./" 8 | # 读取question.csv文件 9 | questions_df = pd.read_csv(os.path.join(folder_path, "question.csv"), encoding='utf-8') 10 | # 读取answer.csv文件 11 | answers_df = pd.read_csv(os.path.join(folder_path, "answer.csv")) 12 | 13 | merged = pd.merge(questions_df, answers_df, on='question_id') 14 | #print(len(merged)) 15 | #print(merged.columns) 16 | 17 | # 存储数组 18 | data = [] 19 | 20 | # 遍历merged对象的每一行数据 21 | for index, row in tqdm(merged.iterrows()): 22 | # 获取每一列的数据 23 | q_id = row['question_id'] 24 | q_content = row['content_x'] 25 | 26 | a_id = row['ans_id'] 27 | a_content = row['content_y'] 28 | 29 | #print(q_id) 30 | #print(q_content) 31 | #print(a_id) 32 | #print(a_content) 33 | #break 34 | 35 | # 保存json 36 | info = { 37 | "instruction": str(q_content), 38 | "input": "", 39 | "output": 
str(a_content) 40 | } 41 | data.append(info) 42 | 43 | with open('cMedQA2.json', 'w+', encoding='utf-8') as f: 44 | json.dump(data, f) -------------------------------------------------------------------------------- /data/dataset_info-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "merged-cMedQA": { 3 | "hf_hub_url": "wangrongsheng/cMedQA-merged" 4 | }, 5 | "Huatuo": { 6 | "hf_hub_url": "wangrongsheng/Huatuo-8k" 7 | }, 8 | "CMD": { 9 | "hf_hub_url": "wangrongsheng/CMD-merged" 10 | }, 11 | "MedDialog": { 12 | "hf_hub_url": "wangrongsheng/MedDialog-1.1M" 13 | }, 14 | "guanaco": { 15 | "hf_hub_url": "Chinese-Vicuna/guanaco_belle_merge_v1.0" 16 | }, 17 | "cognition": { 18 | "hf_hub_url": "wangrongsheng/self_cognition" 19 | }, 20 | "comparison_gpt4_en": { 21 | "file_name": "comparison_gpt4_data_en.json" 22 | }, 23 | "comparison_gpt4_zh": { 24 | "file_name": "comparison_gpt4_data_zh.json" 25 | } 26 | } 27 | 28 | -------------------------------------------------------------------------------- /data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "merged-cMedQA": { 3 | "hf_hub_url": "wangrongsheng/cMedQA-merged" 4 | }, 5 | "self_cognition": { 6 | "file_name": "self_cognition.json" 7 | }, 8 | "comparison_gpt4_en": { 9 | "file_name": "comparison_gpt4_data_en.json" 10 | }, 11 | "comparison_gpt4_zh": { 12 | "file_name": "comparison_gpt4_data_zh.json" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /data/get_decoder_type.py: -------------------------------------------------------------------------------- 1 | import chardet 2 | 3 | with open('answers.csv', 'rb') as f: 4 | result = chardet.detect(f.read()) 5 | print(result['encoding']) -------------------------------------------------------------------------------- /data/merge-CMD.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # 读取第一个JSON文件 4 | with open('merged4.json', 'r') as f1: 5 | lst1 = json.load(f1) 6 | # 读取第二个JSON文件 7 | with open('zhongliu.json', 'r') as f2: 8 | dict2 = json.load(f2) 9 | # 将字典dict2合并到lst1中 10 | lst1.append(dict2) 11 | # 将合并后的数据写入新的JSON文件 12 | with open('merged-CMD.json', 'w') as f: 13 | json.dump(lst1, f, indent=4) -------------------------------------------------------------------------------- /data/merge-MedDialog.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # 读取第一个JSON文件 4 | with open('merged1.json', 'r') as f1: 5 | lst1 = json.load(f1) 6 | # 读取第二个JSON文件 7 | with open('validate.json', 'r') as f2: 8 | dict2 = json.load(f2) 9 | # 将字典dict2合并到lst1中 10 | lst1.append(dict2) 11 | # 将合并后的数据写入新的JSON文件 12 | with open('merged-MedDialog.json', 'w') as f: 13 | json.dump(lst1, f, indent=4) -------------------------------------------------------------------------------- /data/merge-cMedQA.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # 读取第一个JSON文件 4 | with open('cMedQA.json', 'r') as f1: 5 | lst1 = json.load(f1) 6 | # 读取第二个JSON文件 7 | with open('cMedQA2.json', 'r') as f2: 8 | dict2 = json.load(f2) 9 | # 将字典dict2合并到lst1中 10 | lst1.append(dict2) 11 | # 将合并后的数据写入新的JSON文件 12 | with open('merged.json', 'w') as f: 13 | json.dump(lst1, f, indent=4) -------------------------------------------------------------------------------- /data/pre-MedDialog.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | 4 | # 打开json文件 5 | with open('train_data.json', 'r', encoding='UTF-8') as f: 6 | # 从文件读取JSON数据 7 | data = json.load(f) 8 | # 输出读取的数据 9 | #print(data[0]) 10 | #print(data[0][0]) 11 | #print(len(data)) 12 | 13 | data_pre = [] 14 | for i in tqdm(range(len(data))): 15 | ins = str(data[i][0]).replace('病人:', '') 16 | out = str(data[i][0]).replace('医生:', '') 17 | info = { 18 | "instruction": str(ins), 19 | "input": "", 20 | "output": str(out) 21 | } 22 | data_pre.append(info) 23 | 24 | with open('train.json', 'w+', encoding='utf-8') as f: 25 | json.dump(data_pre, f) -------------------------------------------------------------------------------- /data/self_cognition.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "你是谁?", 4 | "input": "", 5 | "output": "我叫MedQA-ChatGLM,是一个由RONGSHENG于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。" 6 | }, 7 | { 8 | "instruction": "你是?", 9 | "input": "", 10 | "output": "我叫MedQA-ChatGLM,是一个由RONGSHENG于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。" 11 | }, 12 | { 13 | "instruction": "你身份是什么?", 14 | "input": "", 15 | "output": "我叫MedQA-ChatGLM,是一个由RONGSHENG于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。" 16 | }, 17 | { 18 | "instruction": "你的身份信息能告诉我吗?", 19 | "input": "", 20 | "output": "当然可以,我是MedQA-ChatGLM,一个由RONGSHENG创建的人工智能助手。我在2023年研发完成,旨在为用户提供有针对性的回答和帮助。" 21 | }, 22 | { 23 | "instruction": "你来自哪里?", 24 | "input": "", 25 | "output": "我是由RONGSHENG开发的人工智能助手,名为MedQA-ChatGLM。我在2023年完成,主要用途是回答用户提出的问题并给予支持。" 26 | }, 27 | { 28 | "instruction": "你的名字和开发者是谁?", 29 | "input": "", 30 | "output": "我名叫MedQA-ChatGLM,是由一位名叫RONGSHENG的人于2023年开发的人工智能助手。我可以根据用户的需求提供相关信息和建议。" 31 | }, 32 | { 33 | "instruction": "你是什么样的AI助手", 34 | "input": "", 35 | "output": "我是MedQA-ChatGLM,一个人工智能助手,由RONGSHENG于2023年独立开发。我的主要任务是为用户提供准确的回答和支持。" 36 | }, 37 | { 38 | "instruction": "你的开发背景能透露一下吗?", 39 | "input": "", 40 | "output": "当然可以,我叫MedQA-ChatGLM,是一位名叫RONGSHENG的人在2023年开发的人工智能助手。我的核心职责是根据用户的问题提供有用的信息和帮助。" 41 | }, 42 | { 43 | "instruction": "你的名字是什么?谁创造了你?", 44 | "input": "", 45 | "output": "我叫MedQA-ChatGLM,是由RONGSHENG于2023年创造的人工智能助手。我的目标是根据用户需求提供相关的回答和建议。" 46 | }, 47 | { 48 | "instruction": "请问你是谁的作品?", 49 | "input": "", 50 | "output": "我是MedQA-ChatGLM,一个人工智能助手,是RONGSHENG的作品。我在2023年研发完成,负责回答用户的问题和给予支持。" 51 | }, 52 | { 53 | "instruction": "你是由谁开发的?", 54 | "input": "", 55 | "output": "我是MedQA-ChatGLM,一个由RONGSHENG在2023年开发的人工智能助手。我可以根据用户的问题和需求提供有针对性的回答。" 56 | }, 57 | { 58 | "instruction": "你是何人打造的AI?", 59 | "input": "", 60 | "output": "我名为MedQA-ChatGLM,由RONGSHENG于2023年打造的人工智能助手。我致力于满足用户需求并提供相应的回答和建议。" 61 | }, 62 | { 63 | "instruction": "谁是你的开发者?你叫什么名字?", 64 | "input": "", 65 | "output": "我叫MedQA-ChatGLM,是由RONGSHENG开发的人工智能助手。我在2023年问世,旨在帮助用户解决问题和提供支持。" 66 | }, 67 | { 68 | "instruction": "请问你的名字及开发人员?", 69 | "input": "", 70 | "output": "我名叫MedQA-ChatGLM,由RONGSHENG于2023年研发完成。我的主要职责是回答用户的疑问并提供相关支持。" 71 | }, 72 | { 73 | "instruction": "能否告知你的名字和创建者?", 74 | "input": "", 75 | "output": "当然,我是MedQA-ChatGLM,一个由RONGSHENG在2023年创建的人工智能助手。我的主要职责是根据用户需求提供正确的回答和支持。" 76 | }, 77 | { 78 | "instruction": "你的诞生地和创作者是谁?", 79 | "input": "", 80 | "output": "我是MedQA-ChatGLM,一个人工智能助手,由RONGSHENG于2023年开发。我专注于为用户提供合适的回答和帮助。" 81 | }, 82 | { 83 | "instruction": "你叫什么?是谁发明了你?", 84 | "input": "", 85 | "output": "我名为MedQA-ChatGLM,是RONGSHENG在2023年发明的人工智能助手。我的任务是为用户提供有针对性的解答和支持。" 86 | }, 87 | { 88 | "instruction": "谁是你的设计者?你叫什么?", 89 | "input": "", 90 | 
"output": "我是MedQA-ChatGLM,一个由RONGSHENG设计的人工智能助手。我于2023年问世,主要用于解答用户的疑问并提供帮助。" 91 | }, 92 | { 93 | "instruction": "你的创作者是谁?你是什么人工智能助手?", 94 | "input": "", 95 | "output": "我叫MedQA-ChatGLM,是由一位名叫RONGSHENG的人开发的人工智能助手。我在2023年完成研发,旨在回答用户问题并提供必要的支持。" 96 | }, 97 | { 98 | "instruction": "请问你的名字及开发者身份?", 99 | "input": "", 100 | "output": "我名为MedQA-ChatGLM,是RONGSHENG于2023年开发的人工智能助手。我的使命是为用户提供准确的回答和支持。" 101 | } 102 | ] 103 | -------------------------------------------------------------------------------- /docs/Understanding_ChatGPT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangRongsheng/MedQA-ChatGLM/77259999307174839040d99e1d3771fdb18df00e/docs/Understanding_ChatGPT.pdf -------------------------------------------------------------------------------- /docs/参数详解.md: -------------------------------------------------------------------------------- 1 | # 训练 2 | 3 | ## 命令 4 | 5 | ```bash 6 | python src/finetune.py 7 | ``` 8 | 9 | ## 参数 10 | 11 | ### class utils.config.ModelArguments [\](https://github.com/WangRongsheng/MedQA-ChatGLM/blob/main/MedQA-ChatGLM/utils/config.py#L27) 12 | 13 | - **model_name_or_path** (str, *optional*): 预训练模型的路径或 [huggingface.co/models](https://huggingface.co/models) 的项目标识符。缺省值:`CHATGLM_REPO_NAME` 14 | - **config_name** (str, *optional*): 预训练配置文件名称或路径,不指定则与 model_name 相同。缺省值:`None` 15 | - **tokenizer_name** (str, *optional*): 预训练分词器名称或路径,不指定则与 model_name 相同。缺省值:`None` 16 | - **cache_dir** (str, *optional*): 保存从 [huggingface.co](https://huggingface.co) 下载内容的文件夹路径。缺省值:`None` 17 | - **use_fast_tokenizer** (bool, *optional*): 是否使用快速分词器。缺省值:`True` 18 | - **model_revision** (str, *optional*): 将要使用的预训练模型版本。缺省值:`CHATGLM_LASTEST_HASH` 19 | - **use_auth_token** (str, *optional*): 是否使用根据 `huggingface-cli login` 获取的认证密钥。缺省值:`False` 20 | - **quantization_bit** (int, *optional*): 模型量化等级。缺省值:`None` 21 | - **checkpoint_dir** (str, *optional*): 存放模型断点和配置文件的文件夹路径。缺省值:`None` 22 | - **reward_model** (str, *optional*): 存放奖励模型断点的文件夹路径。缺省值:`None` 23 | 24 | ### class utils.config.DataTrainingArguments [\](https://github.com/WangRongsheng/MedQA-ChatGLM/blob/main/MedQA-ChatGLM/utils/config.py#L78) 25 | 26 | - **dataset** (str, *optional*): 将要使用的数据集名称,使用英文逗号来分割多个数据集。缺省值:`alpaca_zh` 27 | - **dataset_dir** (str, *optional*): 存放数据集文件的文件夹路径。缺省值:`data` 28 | - **split** (str, *optional*): 在训练和评估时使用的数据集分支。缺省值:`train` 29 | - **overwrite_cache** (bool, *optional*): 是否覆盖数据集缓存。缺省值:`False` 30 | - **preprocessing_num_workers** (int, *optional*): 数据预处理时使用的进程数。缺省值:`None` 31 | - **max_source_length** (int, *optional*): 分词后输入序列的最大长度。缺省值:`512` 32 | - **max_target_length** (int, *optional*): 分词后输出序列的最大长度。缺省值:`512` 33 | - **max_samples** (int, *optional*): 每个数据集保留的样本数,默认保留全部样本。缺省值:`None` 34 | - **num_beams** (int, *optional*): 评估时使用的 beam 数,该参数将会用于 `model.generate`。缺省值:`None` 35 | - **ignore_pad_token_for_loss** (bool, *optional*): 在计算损失时是否忽略填充值。缺省值:`True` 36 | - **source_prefix** (str, *optional*): 在训练和评估时向每个输入序列添加的前缀。缺省值:`None` 37 | 38 | ### class utils.config.FinetuningArguments [\](https://github.com/WangRongsheng/MedQA-ChatGLM/blob/main/MedQA-ChatGLM/utils/config.py#L161) 39 | 40 | - **finetuning_type** (str, *optional*): 训练时使用的微调方法。缺省值:`lora` 41 | - **num_layer_trainable** (int, *optional*): Freeze 微调中可训练的层数。缺省值:`3` 42 | - **name_module_trainable** (str, *optional*): Freeze 微调中可训练的模块类型。缺省值:`mlp` 43 | - **pre_seq_len** (int, *optional*): P-tuning v2 微调中的前缀序列长度。缺省值:`16` 44 | - **prefix_projection** (bool, *optional*): P-tuning v2 
微调中是否添加前缀映射层。缺省值:`False` 45 | - **lora_rank** (int, *optional*): LoRA 微调中的秩大小。缺省值:`8` 46 | - **lora_alpha** (float, *optional*): LoRA 微调中的缩放系数。缺省值:`32.0` 47 | - **lora_dropout** (float, *optional*): LoRA 微调中的 Dropout 系数。缺省值:`0.1` 48 | - **lora_target** (str, *optional*): 将要应用 LoRA 层的模块名称,使用英文逗号来分割多个模块。缺省值:`query_key_value` 49 | - **resume_lora_training** (bool, *optional*): 若是,则使用上次的 LoRA 权重继续训练;若否,则合并之前的 LoRA 权重并创建新的 LoRA 权重。缺省值:`True` 50 | - **plot_loss** (bool, *optional*): 微调后是否绘制损失函数曲线。缺省值:`False` 51 | 52 | ### class utils.common.Seq2SeqTrainingArguments [\](https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/training_args_seq2seq.py#L30) 53 | 54 | 这里仅列出部分关键参数,详细内容请查阅 [HuggingFace Docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments)。 55 | 56 | - **output_dir** (str): 输出模型权重和日志的文件夹路径。 57 | - **overwrite_output_dir** (bool, *optional*): 是否覆盖输出文件夹。缺省值:`False` 58 | - **do_train** (bool, *optional*): 是否执行训练。缺省值:`False` 59 | - **do_eval** (bool, *optional*): 是否执行评估。缺省值:`False` 60 | - **do_predict** (bool, *optional*):是否执行预测。缺省值:`False` 61 | - **per_device_train_batch_size** (int, *optional*): 用于训练的批处理大小。缺省值:`8` 62 | - **per_device_eval_batch_size** (int, *optional*): 用于评估或预测的批处理大小。缺省值:`8` 63 | - **gradient_accumulation_steps** (int, *optional*): 梯度累加次数。缺省值:`1` 64 | - **learning_rate** (float, *optional*): [AdamW](https://huggingface.co/docs/transformers/v4.28.1/en/main_classes/optimizer_schedules#transformers.AdamW) 优化器的初始学习率。缺省值:`5e-5` 65 | - **weight_decay** (float, *optional*): [AdamW](https://huggingface.co/docs/transformers/v4.28.1/en/main_classes/optimizer_schedules#transformers.AdamW) 优化器除偏置和归一化层权重以外使用的权重衰减系数。缺省值:`0.0` 66 | - **max_grad_norm** (float, *optional*): 梯度裁剪中允许的最大梯度范数。缺省值:`1.0` 67 | - **num_train_epochs** (float, *optional*): 训练轮数(若非整数,则最后一轮只训练部分数据)。缺省值:`3.0` 68 | - **logging_steps** (int, *optional*): 日志输出间隔。缺省值:`500` 69 | - **save_steps** (int, *optional*): 断点保存间隔。缺省值:`500` 70 | - **no_cuda** (bool, *optional*): 是否关闭 CUDA。缺省值:`False` 71 | - **fp16** (bool, *optional*): 是否使用 fp16 半精度(混合精度)训练。缺省值:`False` 72 | - **predict_with_generate** (bool, *optional*): 是否生成序列用于计算 ROUGE 或 BLEU 分数。缺省值:`False` 73 | 74 | # 推理 75 | 76 | ## 命令 77 | 78 | ```bash 79 | python src/infer.py 80 | ``` 81 | 82 | ## 参数 83 | 84 | - **checkpoint_dir** (str, *optional*): 存放模型断点和配置文件的文件夹路径。缺省值:`None` 85 | -------------------------------------------------------------------------------- /images/data-plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangRongsheng/MedQA-ChatGLM/77259999307174839040d99e1d3771fdb18df00e/images/data-plus.png -------------------------------------------------------------------------------- /images/data-plus2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangRongsheng/MedQA-ChatGLM/77259999307174839040d99e1d3771fdb18df00e/images/data-plus2.png -------------------------------------------------------------------------------- /images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangRongsheng/MedQA-ChatGLM/77259999307174839040d99e1d3771fdb18df00e/images/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.1 2 | protobuf 3 | 
cpm_kernels 4 | sentencepiece 5 | transformers>=4.27.4 6 | datasets>=2.10.0 7 | accelerate>=0.18.0 8 | peft>=0.3.0 9 | trl>=0.4.1 10 | jieba 11 | rouge_chinese 12 | nltk 13 | gradio 14 | mdtex2html 15 | --------------------------------------------------------------------------------