├── CITATION.cff
├── MedQA-ChatGLM
│   ├── __init__.py
│   ├── export_weights.py
│   ├── finetune.py
│   ├── infer.py
│   ├── load_export_weights.py
│   ├── train_ppo.py
│   ├── train_rm.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── config.py
│   │   ├── other.py
│   │   ├── pairwise.py
│   │   ├── ppo.py
│   │   └── seq2seq.py
│   └── web_demo.py
├── README.md
├── data
│   ├── bulid_CMD.py
│   ├── bulid_cMedQA.py
│   ├── bulid_cMedQA2.py
│   ├── comparison_gpt4_data_en.json
│   ├── comparison_gpt4_data_zh.json
│   ├── dataset_info-plus.json
│   ├── dataset_info.json
│   ├── get_decoder_type.py
│   ├── merge-CMD.py
│   ├── merge-MedDialog.py
│   ├── merge-cMedQA.py
│   ├── pre-MedDialog.py
│   └── self_cognition.json
├── docs
│   ├── Understanding_ChatGPT.pdf
│   └── 参数详解.md
├── images
│   ├── data-plus.png
│   ├── data-plus2.png
│   └── model.png
└── requirements.txt

/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: MedQA-ChatGLM
3 | message: >-
4 |   If you use this software, please cite it using these
5 |   metadata.
6 | type: software
7 | authors:
8 |   - given-names: Rongsheng Wang
9 |     orcid: https://orcid.org/my-orcid?orcid=0000-0003-2390-5999
10 | repository-code: 'https://github.com/WangRongsheng/MedQA-ChatGLM'
11 | url: 'https://github.com/WangRongsheng/MedQA-ChatGLM'
12 | abstract: A Medical QA Model Fine-tuned on ChatGLM Using Multiple Fine-tuning Methods and Real Medical QA Data
13 | license: CC BY-NC-SA 4.0
14 | 
--------------------------------------------------------------------------------
/MedQA-ChatGLM/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import (
2 |     load_pretrained,
3 |     ModelArguments
4 | )
5 | 
--------------------------------------------------------------------------------
/MedQA-ChatGLM/export_weights.py:
--------------------------------------------------------------------------------
1 | from utils import load_pretrained, ModelArguments
2 | import argparse
3 | 
4 | if __name__ == "__main__":
5 |     parser = argparse.ArgumentParser()
6 |     # add command-line arguments
7 |     parser.add_argument('-finetuning_weights_path', '--finetuning_weights_path', dest='fwp', type=str, required=True)
8 |     parser.add_argument('-save_weights_path', '--save_weights_path', dest='swp', type=str, required=True)
9 |     # parse arguments
10 |     args = parser.parse_args()
11 | 
12 |     model_args = ModelArguments(checkpoint_dir=args.fwp)
13 |     model, tokenizer = load_pretrained(model_args)
14 |     # save the merged weights
15 |     model.base_model.model.save_pretrained(args.swp)
16 |     # save the tokenizer
17 |     tokenizer.save_pretrained(args.swp)
18 | 
19 |     print('合并模型完成,保存在:', args.swp)
--------------------------------------------------------------------------------
/MedQA-ChatGLM/finetune.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Implements several parameter-efficient supervised fine-tuning methods for ChatGLM.
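# Example invocation (an illustrative sketch, not taken from the repository:
# the dataset name, output path and hyperparameter values below are assumptions;
# see utils/config.py and prepare_args() in utils/common.py for the full
# argument list):
#
#   CUDA_VISIBLE_DEVICES=0 python finetune.py \
#       --do_train \
#       --dataset alpaca_zh \
#       --finetuning_type lora \
#       --output_dir output/sft \
#       --per_device_train_batch_size 4 \
#       --learning_rate 5e-5 \
#       --num_train_epochs 3.0 \
#       --fp16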
3 | # This code is inspired by https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py 4 | 5 | 6 | from utils import ( 7 | load_pretrained, 8 | prepare_args, 9 | prepare_data, 10 | preprocess_data, 11 | plot_loss, 12 | Seq2SeqDataCollatorForChatGLM, 13 | ComputeMetrics, 14 | Seq2SeqTrainerForChatGLM 15 | ) 16 | 17 | 18 | def main(): 19 | 20 | # Prepare pretrained model and dataset 21 | model_args, data_args, training_args, finetuning_args = prepare_args() 22 | dataset = prepare_data(model_args, data_args) 23 | model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="sft") 24 | dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft") 25 | data_collator = Seq2SeqDataCollatorForChatGLM( 26 | tokenizer=tokenizer, 27 | model=model, 28 | ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss, 29 | inference_mode=(not training_args.do_train) 30 | ) 31 | 32 | # Override the decoding parameters of Trainer 33 | training_args.generation_max_length = training_args.generation_max_length if \ 34 | training_args.generation_max_length is not None else data_args.max_target_length 35 | training_args.generation_num_beams = data_args.num_beams if \ 36 | data_args.num_beams is not None else training_args.generation_num_beams 37 | 38 | # Initialize our Trainer 39 | trainer = Seq2SeqTrainerForChatGLM( 40 | finetuning_args=finetuning_args, 41 | model=model, 42 | args=training_args, 43 | train_dataset=dataset if training_args.do_train else None, 44 | eval_dataset=dataset if training_args.do_eval else None, 45 | tokenizer=tokenizer, 46 | data_collator=data_collator, 47 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None 48 | ) 49 | 50 | # Keyword arguments for `model.generate` 51 | gen_kwargs = { 52 | "do_sample": True, 53 | "top_p": 0.7, 54 | "max_length": 768, 55 | "temperature": 0.95 56 | } 57 | 58 | # Training 59 | if training_args.do_train: 60 | train_result = trainer.train() 61 | trainer.log_metrics("train", train_result.metrics) 62 | trainer.save_metrics("train", train_result.metrics) 63 | trainer.save_state() # along with the loss values 64 | trainer.save_model() 65 | if finetuning_args.plot_loss: 66 | plot_loss(training_args) 67 | 68 | # Evaluation 69 | if training_args.do_eval: 70 | metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) 71 | trainer.log_metrics("eval", metrics) 72 | trainer.save_metrics("eval", metrics) 73 | 74 | # Predict 75 | if training_args.do_predict: 76 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) 77 | trainer.log_metrics("predict", predict_results.metrics) 78 | trainer.save_metrics("predict", predict_results.metrics) 79 | trainer.save_predictions(predict_results, tokenizer) 80 | 81 | 82 | def _mp_fn(index): 83 | # For xla_spawn (TPUs) 84 | main() 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/infer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implement stream chat in command line for ChatGLM fine-tuned with PEFT. 
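# Example invocation (an illustrative sketch; the checkpoint path is an
# assumption and should point to a directory saved by finetune.py; if
# --checkpoint_dir is omitted, the original ChatGLM-6B weights are loaded):
#
#   python infer.py --checkpoint_dir output/sft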
3 | # This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py 4 | 5 | 6 | import os 7 | import signal 8 | import platform 9 | 10 | from utils import ModelArguments, load_pretrained 11 | from transformers import HfArgumentParser 12 | 13 | 14 | os_name = platform.system() 15 | clear_command = "cls" if os_name == "Windows" else "clear" 16 | stop_stream = False 17 | welcome = "欢迎使用 MedQA-ChatGLM 模型,输入内容即可对话,clear清空对话历史,stop终止程序" 18 | 19 | 20 | def build_prompt(history): 21 | prompt = welcome 22 | for query, response in history: 23 | prompt += f"\n\nUser: {query}" 24 | prompt += f"\n\nChatGLM-6B: {response}" 25 | return prompt 26 | 27 | 28 | def signal_handler(signal, frame): 29 | global stop_stream 30 | stop_stream = True 31 | 32 | 33 | def main(): 34 | 35 | global stop_stream 36 | parser = HfArgumentParser(ModelArguments) 37 | model_args, = parser.parse_args_into_dataclasses() 38 | model, tokenizer = load_pretrained(model_args) 39 | model = model.cuda() 40 | model.eval() 41 | 42 | history = [] 43 | print(welcome) 44 | while True: 45 | try: 46 | query = input("\nInput: ") 47 | except UnicodeDecodeError: 48 | print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") 49 | continue 50 | except Exception: 51 | raise 52 | 53 | if query.strip() == "stop": 54 | break 55 | if query.strip() == "clear": 56 | history = [] 57 | os.system(clear_command) 58 | print(welcome) 59 | continue 60 | 61 | count = 0 62 | for _, history in model.stream_chat(tokenizer, query, history=history): 63 | if stop_stream: 64 | stop_stream = False 65 | break 66 | else: 67 | count += 1 68 | if count % 8 == 0: 69 | os.system(clear_command) 70 | print(build_prompt(history), flush=True) 71 | signal.signal(signal.SIGINT, signal_handler) 72 | os.system(clear_command) 73 | print(build_prompt(history), flush=True) 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/load_export_weights.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoTokenizer 2 | import argparse 3 | 4 | def get_model(load_path): 5 | tokenizer = AutoTokenizer.from_pretrained(load_path, trust_remote_code=True) 6 | config = AutoConfig.from_pretrained(load_path, trust_remote_code=True, pre_seq_len=128) 7 | model = AutoModel.from_pretrained(load_path, config=config, trust_remote_code=True).half().cuda() 8 | model = model.eval() 9 | 10 | return tokenizer, model 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | # 添加参数 15 | parser.add_argument('-save_weights_path', '--save_weights_path', dest='swp', type=str, required=True) 16 | # 解析参数 17 | args = parser.parse_args() 18 | 19 | tokenizer, model = get_model(args.swp) 20 | print(model) 21 | print('加载完成') 22 | 23 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/train_ppo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implements parameter-efficient ppo training of fine-tuned ChatGLM. 
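# Example invocation (an illustrative sketch, all paths and values are
# assumptions: --checkpoint_dir should point to the supervised fine-tuned
# LoRA weights and --reward_model to a checkpoint saved by train_rm.py):
#
#   CUDA_VISIBLE_DEVICES=0 python train_ppo.py \
#       --do_train \
#       --dataset alpaca_zh \
#       --finetuning_type lora \
#       --checkpoint_dir output/sft \
#       --reward_model output/rm \
#       --output_dir output/ppo \
#       --per_device_train_batch_size 2 \
#       --gradient_accumulation_steps 4 \
#       --learning_rate 1e-5 \
#       --num_train_epochs 1.0 \
#       --fp16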
3 | # This code is inspired by: 4 | # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py 5 | 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch.optim import AdamW 10 | 11 | from trl import PPOConfig 12 | from trl.core import LengthSampler 13 | 14 | from utils import ( 15 | prepare_args, 16 | prepare_data, 17 | load_pretrained, 18 | preprocess_data, 19 | PPODataCollatorForChatGLM, 20 | PPOTrainerForChatGLM, 21 | compute_rewards, 22 | plot_loss 23 | ) 24 | 25 | 26 | def main(): 27 | 28 | # prepare pretrained model and dataset 29 | model_args, data_args, training_args, finetuning_args = prepare_args() 30 | dataset = prepare_data(model_args, data_args) 31 | model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="ppo") 32 | dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo") 33 | data_collator = PPODataCollatorForChatGLM( 34 | tokenizer=tokenizer, 35 | min_input_length=data_args.max_source_length, # avoid truncating input sequences 36 | max_input_length=data_args.max_source_length, 37 | inference_mode=(not training_args.do_train) 38 | ) 39 | 40 | ppo_config = PPOConfig( 41 | model_name=model_args.model_name_or_path, 42 | learning_rate=training_args.learning_rate, 43 | mini_batch_size=max(training_args.per_device_train_batch_size // 4, 1), 44 | batch_size=training_args.per_device_train_batch_size, 45 | gradient_accumulation_steps=training_args.gradient_accumulation_steps, 46 | ppo_epochs=int(training_args.num_train_epochs), 47 | max_grad_norm=training_args.max_grad_norm 48 | ) 49 | 50 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate) 51 | 52 | # Initialize our Trainer 53 | ppo_trainer = PPOTrainerForChatGLM( 54 | training_args=training_args, 55 | finetuning_args=finetuning_args, 56 | config=ppo_config, 57 | model=model, 58 | ref_model=None, 59 | tokenizer=tokenizer, 60 | dataset=dataset, 61 | data_collator=data_collator, 62 | optimizer=optimizer 63 | ) 64 | 65 | # Keyword arguments for `model.generate` 66 | gen_kwargs = { 67 | "top_k": 0.0, 68 | "top_p": 1.0, 69 | "do_sample": True, 70 | "pad_token_id": tokenizer.pad_token_id, 71 | "eos_token_id": tokenizer.eos_token_id 72 | } 73 | output_length_sampler = LengthSampler(data_args.max_target_length // 2, data_args.max_target_length) 74 | 75 | for batch in tqdm(ppo_trainer.dataloader): 76 | queries = batch["input_ids"] # left-padded sequences 77 | 78 | model.gradient_checkpointing_disable() 79 | model.config.use_cache = True 80 | 81 | # Get response from ChatGLM 82 | responses_with_queries = ppo_trainer.generate(queries, length_sampler=output_length_sampler, **gen_kwargs) 83 | responses = responses_with_queries[:, queries.size(1):] # right-padded sequences 84 | # batch["response"] = tokenizer.batch_decode(responses, skip_special_tokens=True) # avoid error 85 | 86 | for i in range(responses_with_queries.size(0)): # change to right-padding 87 | start = (responses_with_queries[i] != tokenizer.pad_token_id).nonzero()[0].item() 88 | responses_with_queries[i] = torch.cat((responses_with_queries[i][start:], responses_with_queries[i][:start])) 89 | 90 | # Compute rewards 91 | rewards = compute_rewards(responses_with_queries, model, tokenizer) 92 | 93 | # Run PPO step 94 | model.gradient_checkpointing_enable() 95 | model.config.use_cache = False 96 | split_into_list = lambda x: [x[i] for i in range(x.size(0))] 97 | stats = 
ppo_trainer.step(*map(split_into_list, [queries, responses, rewards])) 98 | 99 | ppo_trainer.log_stats(stats, batch, rewards) 100 | ppo_trainer.update_stats(stats, batch, rewards) 101 | 102 | ppo_trainer.save_state() # along with the loss values 103 | ppo_trainer.save_model() 104 | if finetuning_args.plot_loss: 105 | plot_loss(training_args) 106 | 107 | 108 | def _mp_fn(index): 109 | # For xla_spawn (TPUs) 110 | main() 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/train_rm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implements parameter-efficient training of a reward model based on ChatGLM. 3 | # This code is inspired by: 4 | # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py 5 | # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py 6 | 7 | 8 | from utils import ( 9 | prepare_args, 10 | prepare_data, 11 | load_pretrained, 12 | preprocess_data, 13 | PairwiseDataCollatorForChatGLM, 14 | PairwiseTrainerForChatGLM, 15 | plot_loss 16 | ) 17 | 18 | def main(): 19 | 20 | # prepare pretrained model and dataset 21 | model_args, data_args, training_args, finetuning_args = prepare_args() 22 | dataset = prepare_data(model_args, data_args) 23 | model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="rwd") 24 | dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rwd") 25 | data_collator = PairwiseDataCollatorForChatGLM( 26 | tokenizer=tokenizer, 27 | inference_mode=(not training_args.do_train) 28 | ) 29 | 30 | training_args.remove_unused_columns = False # Important for pairwise dataset 31 | 32 | # Initialize our Trainer 33 | trainer = PairwiseTrainerForChatGLM( 34 | finetuning_args=finetuning_args, 35 | model=model, 36 | args=training_args, 37 | train_dataset=dataset if training_args.do_train else None, 38 | eval_dataset=dataset if training_args.do_eval else None, 39 | tokenizer=tokenizer, 40 | data_collator=data_collator 41 | ) 42 | 43 | # Training 44 | if training_args.do_train: 45 | train_result = trainer.train() 46 | trainer.log_metrics("train", train_result.metrics) 47 | trainer.save_metrics("train", train_result.metrics) 48 | trainer.save_state() # along with the loss values 49 | trainer.save_model() 50 | if finetuning_args.plot_loss: 51 | plot_loss(training_args) 52 | 53 | 54 | def _mp_fn(index): 55 | # For xla_spawn (TPUs) 56 | main() 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import ( 2 | load_pretrained, 3 | prepare_args, 4 | prepare_data, 5 | preprocess_data 6 | ) 7 | 8 | from .seq2seq import ( 9 | Seq2SeqDataCollatorForChatGLM, 10 | ComputeMetrics, 11 | Seq2SeqTrainerForChatGLM 12 | ) 13 | 14 | from .pairwise import ( 15 | PairwiseDataCollatorForChatGLM, 16 | PairwiseTrainerForChatGLM 17 | ) 18 | 19 | from .ppo import ( 20 | PPODataCollatorForChatGLM, 21 | PPOTrainerForChatGLM, 22 | compute_rewards 23 | ) 24 | 25 | from .config import ModelArguments 26 | 27 | from .other import plot_loss 28 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/common.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import hashlib 5 | import logging 6 | from typing import Literal, Optional, Tuple 7 | 8 | import transformers 9 | from transformers import ( 10 | AutoConfig, 11 | AutoModel, 12 | AutoTokenizer, 13 | HfArgumentParser, 14 | Seq2SeqTrainingArguments 15 | ) 16 | from transformers.utils import check_min_version 17 | from transformers.utils.versions import require_version 18 | from transformers.modeling_utils import PreTrainedModel 19 | from transformers.tokenization_utils import PreTrainedTokenizer 20 | 21 | import datasets 22 | from datasets import Dataset, concatenate_datasets, load_dataset 23 | 24 | from peft import ( 25 | PeftModel, 26 | TaskType, 27 | LoraConfig, 28 | get_peft_model 29 | ) 30 | 31 | from trl import AutoModelForCausalLMWithValueHead 32 | 33 | from .config import ( 34 | ModelArguments, 35 | DataTrainingArguments, 36 | FinetuningArguments 37 | ) 38 | 39 | from .other import ( 40 | load_trainable_params, 41 | load_valuehead_params, 42 | print_trainable_params, 43 | prepare_model_for_training, 44 | IGNORE_INDEX, 45 | FINETUNING_ARGS_NAME 46 | ) 47 | 48 | 49 | logger = logging.getLogger(__name__) # setup logging 50 | logger.setLevel(logging.INFO) 51 | logging.basicConfig( 52 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 53 | datefmt="%m/%d/%Y %H:%M:%S", 54 | handlers=[logging.StreamHandler(sys.stdout)], 55 | ) 56 | 57 | 58 | check_min_version("4.27.4") 59 | require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0") 60 | require_version("peft>=0.3.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") 61 | require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") 62 | 63 | 64 | def init_adapter( 65 | model: PreTrainedModel, 66 | model_args: ModelArguments, 67 | finetuning_args: FinetuningArguments, 68 | is_trainable: bool 69 | ) -> None: 70 | r""" 71 | Initializes the adapters. 72 | 73 | Note that the trainable parameters must be cast to float32. 
74 | """ 75 | 76 | if finetuning_args.finetuning_type == "none" and is_trainable: 77 | raise ValueError("You cannot use finetuning_type=none while training.") 78 | 79 | if finetuning_args.finetuning_type == "freeze": 80 | logger.info("Fine-tuning method: Freeze") 81 | for name, param in model.named_parameters(): 82 | if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers): 83 | param.requires_grad_(False) 84 | else: 85 | param.data = param.data.to(torch.float32) 86 | 87 | if model_args.checkpoint_dir is not None: # freeze only accepts a single checkpoint 88 | load_trainable_params(model, model_args.checkpoint_dir[0]) 89 | 90 | if finetuning_args.finetuning_type == "p_tuning": 91 | logger.info("Fine-tuning method: P-Tuning v2") 92 | model.transformer.prefix_encoder.float() # other parameters are already fixed 93 | 94 | if model_args.checkpoint_dir is not None: # p-tuning v2 only accepts a single checkpoint 95 | load_trainable_params(model, model_args.checkpoint_dir[0]) 96 | 97 | if finetuning_args.finetuning_type == "lora": 98 | logger.info("Fine-tuning method: LoRA") 99 | lastest_checkpoint = None 100 | 101 | if model_args.checkpoint_dir is not None: 102 | if is_trainable and finetuning_args.resume_lora_training: # continually training on the lora weights 103 | checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] 104 | else: 105 | checkpoints_to_merge = model_args.checkpoint_dir 106 | 107 | for checkpoint in checkpoints_to_merge: # https://github.com/huggingface/peft/issues/280#issuecomment-1500805831 108 | model = PeftModel.from_pretrained(model, checkpoint) 109 | model = model.merge_and_unload() 110 | 111 | logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) 112 | 113 | if lastest_checkpoint is not None: # resume lora training 114 | model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=True) 115 | 116 | if lastest_checkpoint is None: # create new lora weights 117 | lora_config = LoraConfig( 118 | task_type=TaskType.CAUSAL_LM, 119 | inference_mode=False, 120 | r=finetuning_args.lora_rank, 121 | lora_alpha=finetuning_args.lora_alpha, 122 | lora_dropout=finetuning_args.lora_dropout, 123 | target_modules=finetuning_args.lora_target 124 | ) 125 | model = get_peft_model(model, lora_config) 126 | 127 | if not is_trainable: 128 | for param in model.parameters(): 129 | param.requires_grad_(False) # fix all params 130 | param.data = param.data.to(torch.float16) # cast all params to float16 131 | 132 | return model 133 | 134 | 135 | def load_pretrained( 136 | model_args: ModelArguments, 137 | training_args: Optional[Seq2SeqTrainingArguments] = None, 138 | finetuning_args: Optional[FinetuningArguments] = None, 139 | is_trainable: Optional[bool] = False, 140 | stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft" 141 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 142 | r""" 143 | Load pretrained model and tokenizer. 
144 | """ 145 | 146 | if (not is_trainable) and (model_args.checkpoint_dir is None): 147 | logger.warning("Checkpoint is not found at evaluation, load the original model.") 148 | finetuning_args = FinetuningArguments(finetuning_type="none") 149 | 150 | if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint 151 | for checkpoint_dir in model_args.checkpoint_dir: 152 | if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)): 153 | raise ValueError("The fine-tuning arguments are not found in the provided dictionary.") 154 | logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) 155 | finetuning_args = torch.load(os.path.join(model_args.checkpoint_dir[0], FINETUNING_ARGS_NAME)) 156 | 157 | quantization = None 158 | if model_args.quantization_bit is not None: 159 | if is_trainable: 160 | if finetuning_args.finetuning_type != "p_tuning": 161 | quantization = "bnb" # use bnb's quantization 162 | else: 163 | quantization = "cpm" # use cpm's quantization 164 | else: 165 | quantization = "cpm" 166 | 167 | config_kwargs = { 168 | "trust_remote_code": True, 169 | "cache_dir": model_args.cache_dir, 170 | "revision": model_args.model_revision, 171 | "use_auth_token": True if model_args.use_auth_token else None, 172 | } 173 | 174 | tokenizer = AutoTokenizer.from_pretrained( 175 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 176 | use_fast=model_args.use_fast_tokenizer, 177 | padding_side="left", 178 | **config_kwargs 179 | ) 180 | 181 | config = AutoConfig.from_pretrained( 182 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 183 | **config_kwargs 184 | ) 185 | 186 | # P-Tuning v2 configurations. 187 | # We use the built-in p-tuning method of ChatGLM, we cannot use PEFT since the attention masks of ChatGLM are unusual. >_< 188 | if finetuning_args.finetuning_type == "p_tuning": 189 | config.pre_seq_len = finetuning_args.pre_seq_len # enable this will fix other parameters automatically 190 | config.prefix_projection = finetuning_args.prefix_projection 191 | 192 | # Quantization configurations for Freeze and LoRA in training (using bitsandbytes library). 193 | if quantization == "bnb": 194 | if model_args.quantization_bit != 8: 195 | raise ValueError("Freeze and LoRA fine-tuning only accept 8-bit quantization.") 196 | require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.") 197 | from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible 198 | cuda = get_cuda_lib_handle() 199 | cc = get_compute_capability(cuda) 200 | if not is_cublasLt_compatible(cc): 201 | raise ValueError("The current GPU(s) is incompatible with quantization.") 202 | config_kwargs["load_in_8bit"] = True 203 | config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit 204 | 205 | # Load and prepare pretrained models (without valuehead). 206 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs) 207 | model = prepare_model_for_training(model) if is_trainable else model 208 | model = init_adapter(model, model_args, finetuning_args, is_trainable) 209 | 210 | # Quantization with the built-in method for P-Tuning v2 training or evaluation. 211 | # Model parameters should be cast to float16 in quantized P-Tuning setting. 
212 | if quantization == "cpm": 213 | if model_args.quantization_bit != 4 and model_args.quantization_bit != 8: 214 | raise ValueError("P-Tuning v2 and inference modes only accept 4-bit or 8-bit quantization.") 215 | 216 | if is_trainable and training_args.fp16: 217 | raise ValueError("FP16 training conflicts with cpm quantization.") 218 | 219 | model = model.quantize(model_args.quantization_bit) 220 | for name, param in model.named_parameters(): 221 | if "prefix_encoder" not in name: 222 | param.data = param.data.to(torch.float16) # convert all params in half precision except prefix_encoder 223 | 224 | if quantization is not None: 225 | logger.info("Quantized model to {} bit.".format(model_args.quantization_bit)) 226 | 227 | if stage == "rwd" or stage == "ppo": # add value head 228 | model = AutoModelForCausalLMWithValueHead.from_pretrained(model) 229 | if stage == "ppo": # load reward model 230 | model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) 231 | load_valuehead_params(model, model_args.reward_model) 232 | 233 | # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model 234 | # To meet the compliance requirements of the transformers library 235 | if quantization == "bnb" and model_args.quantization_bit == 8: 236 | model._is_int8_training_enabled = True 237 | 238 | print_trainable_params(model) 239 | 240 | return model, tokenizer 241 | 242 | 243 | def prepare_args() -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]: 244 | 245 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments)) 246 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 247 | # Provide arguments with a json file. 248 | model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 249 | else: 250 | model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() 251 | 252 | # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) 253 | if int(training_args.do_train) + int(training_args.do_eval) + int(training_args.do_predict) != 1: 254 | raise ValueError("We must perform single operation among do_train, do_eval and do_predict.") 255 | 256 | if model_args.quantization_bit is not None and training_args.do_train == False: 257 | logger.warning("We do not recommend to evaluaute model in 4/8-bit mode.") 258 | 259 | if not training_args.fp16: 260 | logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.") 261 | 262 | training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning 263 | 264 | # Set logger 265 | if training_args.should_log: 266 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
267 | transformers.utils.logging.set_verbosity_info() 268 | 269 | log_level = training_args.get_process_log_level() 270 | logger.setLevel(log_level) 271 | datasets.utils.logging.set_verbosity(log_level) 272 | transformers.utils.logging.set_verbosity(log_level) 273 | transformers.utils.logging.enable_default_handler() 274 | transformers.utils.logging.enable_explicit_format() 275 | 276 | # Log on each process the small summary: 277 | logger.warning( 278 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" 279 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 280 | ) 281 | logger.info(f"Training/evaluation parameters {training_args}") 282 | 283 | # Set seed before initializing model. 284 | transformers.set_seed(training_args.seed) 285 | 286 | return model_args, data_args, training_args, finetuning_args 287 | 288 | 289 | def prepare_data( 290 | model_args: ModelArguments, 291 | data_args: DataTrainingArguments 292 | ) -> Dataset: 293 | 294 | def checksum(file_path, hash): 295 | with open(file_path, "rb") as datafile: 296 | binary_data = datafile.read() 297 | sha1 = hashlib.sha1(binary_data).hexdigest() 298 | if sha1 != hash: 299 | logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path)) 300 | 301 | max_samples = data_args.max_samples 302 | all_datasets = [] # support multiple datasets 303 | 304 | for dataset_info in data_args.dataset_list: 305 | 306 | logger.info("Loading dataset {}...".format(dataset_info)) 307 | 308 | if dataset_info.load_from == "hf_hub": 309 | raw_datasets = load_dataset(dataset_info.dataset_name, cache_dir=model_args.cache_dir) 310 | elif dataset_info.load_from == "script": 311 | raw_datasets = load_dataset( 312 | os.path.join(data_args.dataset_dir, dataset_info.dataset_name), 313 | cache_dir=model_args.cache_dir 314 | ) 315 | elif dataset_info.load_from == "file": 316 | data_file = os.path.join(data_args.dataset_dir, dataset_info.file_name) # support json, jsonl and csv 317 | extension = dataset_info.file_name.split(".")[-1] 318 | 319 | if dataset_info.file_sha1 is not None: 320 | checksum(data_file, dataset_info.file_sha1) 321 | else: 322 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.") 323 | 324 | raw_datasets = load_dataset( 325 | extension, 326 | data_files=data_file, 327 | cache_dir=model_args.cache_dir, 328 | use_auth_token=True if model_args.use_auth_token else None 329 | ) 330 | else: 331 | raise NotImplementedError 332 | 333 | dataset = raw_datasets[data_args.split] 334 | 335 | if max_samples is not None: 336 | max_samples_temp = min(len(dataset), max_samples) 337 | dataset = dataset.select(range(max_samples_temp)) 338 | 339 | dummy_data = [None] * len(dataset) 340 | for column, column_name in [ 341 | ("prompt_column", "prompt"), 342 | ("query_column", "query"), 343 | ("response_column", "response"), 344 | ("history_column", "history") 345 | ]: # every dataset will have 4 columns same as each other 346 | if getattr(dataset_info, column) != column_name: 347 | if getattr(dataset_info, column): 348 | dataset = dataset.rename_column(getattr(dataset_info, column), column_name) 349 | else: # None or empty string 350 | dataset = dataset.add_column(column_name, dummy_data) 351 | all_datasets.append(dataset) 352 | 353 | if len(data_args.dataset_list) == 1: 354 | all_datasets = all_datasets[0] 355 | else: 356 | all_datasets = concatenate_datasets(all_datasets) 357 | 358 | return 
all_datasets 359 | 360 | 361 | def preprocess_data( 362 | dataset: Dataset, 363 | tokenizer: PreTrainedTokenizer, 364 | data_args: DataTrainingArguments, 365 | training_args: Seq2SeqTrainingArguments, 366 | stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft" 367 | ) -> Dataset: 368 | 369 | column_names = list(dataset.column_names) 370 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 371 | 372 | def format_example(examples): # support question with a single answer or multiple answers 373 | for i in range(len(examples["prompt"])): 374 | if examples["prompt"][i] and examples["response"][i]: 375 | query, answer = examples["prompt"][i], examples["response"][i] 376 | if examples["query"][i]: 377 | query += examples["query"][i] 378 | if examples["history"][i]: 379 | prompt = "" 380 | history = examples["history"][i] 381 | for i, (old_query, response) in enumerate(history): 382 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) 383 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 384 | else: 385 | prompt = query 386 | prompt = prefix + prompt 387 | yield prompt, answer 388 | 389 | def preprocess_function_train(examples): 390 | # build inputs with format `X [gMASK] [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] [BOS] Y [EOS]` 391 | model_inputs = {"input_ids": [], "labels": []} 392 | for prompt, answer in format_example(examples): 393 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 394 | target_ids = tokenizer.encode(text=answer, add_special_tokens=False) 395 | 396 | if len(source_ids) > data_args.max_source_length - 1: # gmask token 397 | source_ids = source_ids[:data_args.max_source_length - 1] 398 | if len(target_ids) > data_args.max_target_length - 2: # bos and eos tokens 399 | target_ids = target_ids[:data_args.max_target_length - 2] 400 | 401 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids) 402 | 403 | context_length = input_ids.index(tokenizer.bos_token_id) 404 | labels = [IGNORE_INDEX] * context_length + input_ids[context_length:] 405 | 406 | model_inputs["input_ids"].append(input_ids) 407 | model_inputs["labels"].append(labels) 408 | return model_inputs 409 | 410 | def preprocess_function_eval(examples): 411 | # build inputs with format `[PAD] ... 
[PAD] X [gMASK] [BOS]` and labels with format `Y [gMASK] [BOS]` 412 | # left-padding is needed for prediction, use the built-in function of the tokenizer 413 | inputs, targets = [], [] 414 | for prompt, answer in format_example(examples): 415 | inputs.append(prompt) 416 | targets.append(answer) 417 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) 418 | labels = tokenizer(text_target=targets, max_length=data_args.max_target_length, truncation=True) # no padding 419 | if data_args.ignore_pad_token_for_loss: 420 | labels["input_ids"] = [ 421 | [(l_id if l_id != tokenizer.pad_token_id else IGNORE_INDEX) for l_id in label] for label in labels["input_ids"] 422 | ] 423 | model_inputs["labels"] = labels["input_ids"] 424 | return model_inputs 425 | 426 | def preprocess_function_train_pair(examples): 427 | # build input pairs with format `X [gMASK] [BOS] Y1 [EOS]` and `X [gMASK] [BOS] Y2 [EOS]` 428 | model_inputs = {"accept_ids": [], "reject_ids": []} 429 | for prompt, answer in format_example(examples): 430 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 431 | accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) 432 | reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) 433 | 434 | if len(source_ids) > data_args.max_source_length - 1: 435 | source_ids = source_ids[:data_args.max_source_length - 1] 436 | if len(accept_ids) > data_args.max_target_length - 2: 437 | accept_ids = accept_ids[:data_args.max_target_length - 2] 438 | if len(reject_ids) > data_args.max_target_length - 2: 439 | reject_ids = reject_ids[:data_args.max_target_length - 2] 440 | 441 | accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids, accept_ids) 442 | reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids, reject_ids) 443 | 444 | model_inputs["accept_ids"].append(accept_ids) 445 | model_inputs["reject_ids"].append(reject_ids) 446 | return model_inputs 447 | 448 | def preprocess_function_train_ppo(examples): 449 | # build inputs with format `X [gMASK] [BOS]` 450 | model_inputs = {"input_ids": []} 451 | for prompt, _ in format_example(examples): 452 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 453 | 454 | if len(source_ids) > data_args.max_source_length - 1: # gmask token 455 | source_ids = source_ids[:data_args.max_source_length - 1] 456 | 457 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids) 458 | model_inputs["input_ids"].append(input_ids) 459 | return model_inputs 460 | 461 | def print_sft_dataset_example(example): 462 | print("input_ids:\n{}".format(example["input_ids"])) 463 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) 464 | print("label_ids:\n{}".format(example["labels"])) 465 | print("labels:\n{}".format(tokenizer.decode(example["labels"]))) 466 | 467 | def print_pairwise_dataset_example(example): 468 | print("accept_ids:\n{}".format(example["accept_ids"])) 469 | print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"]))) 470 | print("reject_ids:\n{}".format(example["reject_ids"])) 471 | print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"]))) 472 | 473 | def print_ppo_dataset_example(example): 474 | print("input_ids:\n{}".format(example["input_ids"])) 475 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) 476 | 477 | if stage == "sft": 478 | preprocess_function = preprocess_function_train if training_args.do_train else preprocess_function_eval 479 | elif stage == 
"rwd": 480 | preprocess_function = preprocess_function_train_pair 481 | elif stage == "ppo": 482 | preprocess_function = preprocess_function_train_ppo 483 | 484 | with training_args.main_process_first(desc="dataset map pre-processing"): 485 | dataset = dataset.map( 486 | preprocess_function, 487 | batched=True, 488 | num_proc=data_args.preprocessing_num_workers, 489 | remove_columns=column_names, 490 | load_from_cache_file=not data_args.overwrite_cache, 491 | desc="Running tokenizer on dataset" 492 | ) 493 | 494 | if stage == "sft": 495 | print_sft_dataset_example(dataset[0]) 496 | elif stage == "rwd": 497 | print_pairwise_dataset_example(dataset[0]) 498 | elif stage == "ppo": 499 | print_ppo_dataset_example(dataset[0]) 500 | 501 | return dataset 502 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Optional 4 | from dataclasses import dataclass, field 5 | 6 | 7 | CHATGLM_REPO_NAME = "THUDM/chatglm-6b" 8 | CHATGLM_LASTEST_HASH = "a8ede826cf1b62bd3c78bdfb3625c7c5d2048fbd" 9 | 10 | 11 | @dataclass 12 | class DatasetAttr: 13 | 14 | load_from: str 15 | dataset_name: Optional[str] = None 16 | file_name: Optional[str] = None 17 | file_sha1: Optional[str] = None 18 | 19 | def __post_init__(self): 20 | self.prompt_column = "instruction" 21 | self.query_column = "input" 22 | self.response_column = "output" 23 | self.history_column = None 24 | 25 | 26 | @dataclass 27 | class ModelArguments: 28 | """ 29 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 30 | """ 31 | model_name_or_path: Optional[str] = field( 32 | default=CHATGLM_REPO_NAME, 33 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} 34 | ) 35 | config_name: Optional[str] = field( 36 | default=None, 37 | metadata={"help": "Pretrained config name or path if not the same as model_name."} 38 | ) 39 | tokenizer_name: Optional[str] = field( 40 | default=None, 41 | metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."} 42 | ) 43 | cache_dir: Optional[str] = field( 44 | default=None, 45 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} 46 | ) 47 | use_fast_tokenizer: Optional[bool] = field( 48 | default=True, 49 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} 50 | ) 51 | model_revision: Optional[str] = field( 52 | default=CHATGLM_LASTEST_HASH, 53 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 54 | ) 55 | use_auth_token: Optional[bool] = field( 56 | default=False, 57 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."} 58 | ) 59 | quantization_bit: Optional[int] = field( 60 | default=None, 61 | metadata={"help": "The number of bits to quantize the model."} 62 | ) 63 | checkpoint_dir: Optional[str] = field( 64 | default=None, 65 | metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."} 66 | ) 67 | reward_model: Optional[str] = field( 68 | default=None, 69 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."} 70 | ) 71 | 72 | def __post_init__(self): 73 | if self.checkpoint_dir is not None: # support merging lora weights 74 | self.checkpoint_dir = [cd.strip() for cd in 
self.checkpoint_dir.split(",")] 75 | 76 | 77 | @dataclass 78 | class DataTrainingArguments: 79 | """ 80 | Arguments pertaining to what data we are going to input our model for training and evaluation. 81 | """ 82 | dataset: Optional[str] = field( 83 | default="alpaca_zh", 84 | metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."} 85 | ) 86 | dataset_dir: Optional[str] = field( 87 | default="data", 88 | metadata={"help": "The name of the folder containing datasets."} 89 | ) 90 | split: Optional[str] = field( 91 | default="train", 92 | metadata={"help": "Which dataset split to use for training and evaluation."} 93 | ) 94 | overwrite_cache: Optional[bool] = field( 95 | default=False, 96 | metadata={"help": "Overwrite the cached training and evaluation sets."} 97 | ) 98 | preprocessing_num_workers: Optional[int] = field( 99 | default=None, 100 | metadata={"help": "The number of processes to use for the preprocessing."} 101 | ) 102 | max_source_length: Optional[int] = field( 103 | default=512, 104 | metadata={"help": "The maximum total input sequence length after tokenization."} 105 | ) 106 | max_target_length: Optional[int] = field( 107 | default=512, 108 | metadata={"help": "The maximum total output sequence length after tokenization."} 109 | ) 110 | max_samples: Optional[int] = field( 111 | default=None, 112 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} 113 | ) 114 | num_beams: Optional[int] = field( 115 | default=None, 116 | metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} 117 | ) 118 | ignore_pad_token_for_loss: Optional[bool] = field( 119 | default=True, 120 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} 121 | ) 122 | source_prefix: Optional[str] = field( 123 | default=None, 124 | metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 125 | ) 126 | 127 | def __post_init__(self): # support mixing multiple datasets 128 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] 129 | dataset_info = json.load(open(os.path.join(self.dataset_dir, "dataset_info.json"), "r")) 130 | 131 | self.dataset_list = [] 132 | for name in dataset_names: 133 | if name not in dataset_info: 134 | raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) 135 | 136 | if "hf_hub_url" in dataset_info[name]: 137 | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) 138 | elif "script_url" in dataset_info[name]: 139 | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) 140 | else: 141 | dataset_attr = DatasetAttr( 142 | "file", 143 | file_name=dataset_info[name]["file_name"], 144 | file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None 145 | ) 146 | 147 | if "columns" in dataset_info[name]: 148 | dataset_attr.prompt_column = dataset_info[name]["columns"]["prompt"] 149 | dataset_attr.query_column = dataset_info[name]["columns"]["query"] 150 | dataset_attr.response_column = dataset_info[name]["columns"]["response"] 151 | dataset_attr.history_column = dataset_info[name]["columns"]["history"] 152 | 153 | self.dataset_list.append(dataset_attr) 154 | 155 | 156 | @dataclass 157 | class FinetuningArguments: 158 | """ 159 | Arguments pertaining to which techniques we are going to fine-tuning with. 
160 | """ 161 | finetuning_type: Optional[str] = field( 162 | default="lora", 163 | metadata={"help": "Which fine-tuning method to use."} 164 | ) 165 | num_layer_trainable: Optional[int] = field( 166 | default=3, 167 | metadata={"help": "Number of trainable layers for Freeze fine-tuning."} 168 | ) 169 | name_module_trainable: Optional[str] = field( 170 | default="mlp", 171 | metadata={"help": "Name of trainable modules for Freeze fine-tuning."} 172 | ) 173 | pre_seq_len: Optional[int] = field( 174 | default=16, 175 | metadata={"help": "Number of prefix tokens to use for P-tuning V2."} 176 | ) 177 | prefix_projection: Optional[bool] = field( 178 | default=False, 179 | metadata={"help": "Whether to add a project layer for the prefix in P-tuning V2 or not."} 180 | ) 181 | lora_rank: Optional[int] = field( 182 | default=8, 183 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} 184 | ) 185 | lora_alpha: Optional[float] = field( 186 | default=32.0, 187 | metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"} 188 | ) 189 | lora_dropout: Optional[float] = field( 190 | default=0.1, 191 | metadata={"help": "Dropout rate for the LoRA fine-tuning."} 192 | ) 193 | lora_target: Optional[str] = field( 194 | default="query_key_value", 195 | metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."} 196 | ) 197 | resume_lora_training: Optional[bool] = field( 198 | default=True, 199 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} 200 | ) 201 | plot_loss: Optional[bool] = field( 202 | default=False, 203 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."} 204 | ) 205 | 206 | def __post_init__(self): 207 | self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA 208 | 209 | if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 210 | trainable_layer_ids = [27-k for k in range(self.num_layer_trainable)] 211 | else: # fine-tuning the first n layers if num_layer_trainable < 0 212 | trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] 213 | if self.name_module_trainable == "mlp": 214 | self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids] 215 | elif self.name_module_trainable == "qkv": 216 | self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids] 217 | 218 | if self.finetuning_type not in ["none", "freeze", "p_tuning", "lora"]: 219 | raise NotImplementedError("Invalid fine-tuning method.") 220 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/other.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import logging 6 | from typing import Dict, List, Optional 7 | 8 | from transformers import Seq2SeqTrainingArguments 9 | from transformers.trainer import TRAINER_STATE_NAME 10 | from transformers.modeling_utils import PreTrainedModel 11 | 12 | from peft.utils.other import WEIGHTS_NAME 13 | 14 | 15 | IGNORE_INDEX = -100 16 | VALUE_HEAD_FILE_NAME = "value_head.bin" 17 | FINETUNING_ARGS_NAME = "finetuning_args.bin" 18 | PREDICTION_FILE_NAME = "generated_predictions.txt" 19 | 20 | 21 | logger = logging.getLogger(__name__) # setup logging 22 | logger.setLevel(logging.INFO) 23 | 
logging.basicConfig( 24 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 25 | datefmt="%m/%d/%Y %H:%M:%S", 26 | handlers=[logging.StreamHandler(sys.stdout)], 27 | ) 28 | 29 | 30 | class AverageMeter: 31 | r""" 32 | Computes and stores the average and current value. 33 | """ 34 | def __init__(self): 35 | self.reset() 36 | 37 | def reset(self): 38 | self.val = 0 39 | self.avg = 0 40 | self.sum = 0 41 | self.count = 0 42 | 43 | def update(self, val, n=1): 44 | self.val = val 45 | self.sum += val * n 46 | self.count += n 47 | self.avg = self.sum / self.count 48 | 49 | 50 | # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32 51 | # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35 52 | def prepare_model_for_training( 53 | model: PreTrainedModel, 54 | output_embedding_layer_name: Optional[str] = "lm_head", 55 | use_gradient_checkpointing: Optional[bool] = True, 56 | layer_norm_names: List[str] = ["layernorm"] # for chatglm setting 57 | ) -> PreTrainedModel: 58 | 59 | for name, param in model.named_parameters(): 60 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 61 | param.data = param.data.to(torch.float32) 62 | 63 | if use_gradient_checkpointing: 64 | model.enable_input_require_grads() 65 | model.gradient_checkpointing_enable() 66 | model.config.use_cache = False # turn off when gradient checkpointing is enabled 67 | 68 | if hasattr(model, output_embedding_layer_name): 69 | output_embedding_layer = getattr(model, output_embedding_layer_name) 70 | input_dtype = output_embedding_layer.weight.dtype 71 | 72 | class CastOutputToFloat(torch.nn.Sequential): 73 | 74 | def forward(self, x): 75 | return super().forward(x.to(input_dtype)).to(torch.float32) 76 | 77 | setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) 78 | 79 | return model 80 | 81 | 82 | def print_trainable_params(model: torch.nn.Module) -> None: 83 | trainable_params, all_param = 0, 0 84 | for param in model.parameters(): 85 | num_params = param.numel() 86 | # if using DS Zero 3 and the weights are initialized empty 87 | if num_params == 0 and hasattr(param, "ds_numel"): 88 | num_params = param.ds_numel 89 | all_param += num_params 90 | if param.requires_grad: 91 | trainable_params += num_params 92 | print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( 93 | trainable_params, all_param, 100 * trainable_params / all_param)) 94 | 95 | 96 | def filter_model_params(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # filter out freezed parameters 97 | state_dict = model.state_dict() 98 | filtered_state_dict = {} 99 | for k, v in model.named_parameters(): 100 | if v.requires_grad: 101 | filtered_state_dict[k] = state_dict[k] 102 | return filtered_state_dict 103 | 104 | 105 | def save_trainable_params(save_directory: os.PathLike, model: torch.nn.Module) -> None: 106 | if os.path.isfile(save_directory): 107 | raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file.") 108 | os.makedirs(save_directory, exist_ok=True) 109 | filtered_state_dict = filter_model_params(model) 110 | torch.save(filtered_state_dict, os.path.join(save_directory, WEIGHTS_NAME)) 111 | 112 | 113 | def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None: 114 | weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) 115 | if not 
os.path.exists(weights_file): 116 | raise ValueError(f"Provided path ({checkpoint_dir}) does not contain the pretrained weights.") 117 | model_state_dict = torch.load(weights_file) 118 | model.load_state_dict(model_state_dict, strict=False) # skip missing keys 119 | 120 | 121 | def save_valuehead_params(save_directory: os.PathLike, v_head: torch.nn.Module) -> None: 122 | if os.path.isfile(save_directory): 123 | raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file.") 124 | os.makedirs(save_directory, exist_ok=True) 125 | torch.save(v_head.state_dict(), os.path.join(save_directory, VALUE_HEAD_FILE_NAME)) 126 | 127 | 128 | def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None: 129 | valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME) 130 | if not os.path.exists(valuehead_file): 131 | raise ValueError(f"Provided path ({checkpoint_dir}) does not contain the valuehead weights.") 132 | valuehead_state_dict = torch.load(valuehead_file) 133 | model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"]) 134 | model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"]) 135 | model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"])) 136 | model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])) 137 | 138 | 139 | def plot_loss(training_args: Seq2SeqTrainingArguments) -> None: 140 | import matplotlib.pyplot as plt 141 | FIGURE_NAME = "trainer_state.png" 142 | data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r")) 143 | train_steps, train_losses = [], [] 144 | for i in range(len(data["log_history"]) - 1): 145 | train_steps.append(data["log_history"][i]["step"]) 146 | train_losses.append(data["log_history"][i]["loss"]) 147 | plt.figure() 148 | plt.plot(train_steps, train_losses) 149 | plt.title("training loss of {}".format(training_args.output_dir)) 150 | plt.xlabel("step") 151 | plt.ylabel("training loss") 152 | plt.savefig(os.path.join(training_args.output_dir, FIGURE_NAME), format="png", transparent=True, dpi=300) 153 | print("Figure saved: {}".format(os.path.join(training_args.output_dir, FIGURE_NAME))) 154 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/pairwise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | from typing import Dict, Optional, Sequence 6 | 7 | from transformers import Trainer, DataCollatorWithPadding 8 | from transformers.trainer import TRAINING_ARGS_NAME 9 | from transformers.tokenization_utils import PreTrainedTokenizer 10 | 11 | from .config import FinetuningArguments 12 | 13 | from .other import ( 14 | save_trainable_params, 15 | save_valuehead_params, 16 | FINETUNING_ARGS_NAME 17 | ) 18 | 19 | 20 | logger = logging.getLogger(__name__) # setup logging 21 | logger.setLevel(logging.INFO) 22 | logging.basicConfig( 23 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 24 | datefmt="%m/%d/%Y %H:%M:%S", 25 | handlers=[logging.StreamHandler(sys.stdout)], 26 | ) 27 | 28 | 29 | class PairwiseDataCollatorForChatGLM(DataCollatorWithPadding): 30 | r""" 31 | Data collator for ChatGLM. It is capable of dynamically padding for batched data. 
32 | 33 | Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/65512697dc67779a6e53c267488aba0ec4d7c02a/train.py#L156 34 | """ 35 | def __init__( 36 | self, 37 | tokenizer: PreTrainedTokenizer, 38 | inference_mode: bool = False 39 | ): 40 | super().__init__(tokenizer, padding=True) 41 | self.inference_mode = inference_mode 42 | 43 | def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]: 44 | r""" 45 | Pads batched data to the longest sequence in the batch. We adopt right-padding for pairwise data. 46 | 47 | ChatGLM is able to generate attentions masks and position ids by itself. 48 | """ 49 | if self.inference_mode: 50 | raise NotImplementedError 51 | accept_ids, reject_ids = [[torch.tensor(feature[key]) for feature in features] for key in ("accept_ids", "reject_ids")] 52 | accept_ids = torch.nn.utils.rnn.pad_sequence(accept_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 53 | reject_ids = torch.nn.utils.rnn.pad_sequence(reject_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 54 | features = {"accept_ids": accept_ids, "reject_ids": reject_ids} 55 | return features 56 | 57 | class PairwiseTrainerForChatGLM(Trainer): 58 | r""" 59 | Inherits Trainer to compute pairwise loss. 60 | """ 61 | 62 | def __init__(self, finetuning_args: FinetuningArguments, *args, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | self.finetuning_args = finetuning_args 65 | 66 | def compute_loss(self, model, inputs, return_outputs=False): 67 | r""" 68 | Computes pairwise loss. 69 | 70 | There are two different implmentations: 71 | [1] https://github.com/lvwerra/trl/blob/52fecee8839ad826ad1e6c83a95c99a4116e98d2/examples/summarization/scripts/reward_summarization.py#L181 72 | [2] https://github.com/microsoft/DeepSpeedExamples/blob/f4ad1d5721630185a9088565f9201929a8b1ffdf/applications/DeepSpeed-Chat/training/utils/model/reward_model.py#L37 73 | Now we adopt the first implementation. We will consider adopting the second implementation later. 74 | """ 75 | _, _, r_accept = model(input_ids=inputs["accept_ids"]) 76 | _, _, r_reject = model(input_ids=inputs["reject_ids"]) 77 | s_accept = r_accept.transpose(0, 1)[(inputs["accept_ids"] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)] 78 | s_reject = r_reject.transpose(0, 1)[(inputs["reject_ids"] == self.tokenizer.eos_token_id).nonzero(as_tuple=True)] 79 | loss = -torch.log(torch.sigmoid(s_accept - s_reject)).mean() 80 | if return_outputs: 81 | return loss, {"r_accept": r_accept, "r_reject": r_reject} 82 | return loss 83 | 84 | def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: 85 | r""" 86 | Saves trainable parameters as model checkpoints. Use `self.model.pretrained_model` to refer to the backbone model. 87 | 88 | Override to inject custom behavior. 
89 | """ 90 | output_dir = output_dir if output_dir is not None else self.args.output_dir 91 | os.makedirs(output_dir, exist_ok=True) 92 | logger.info(f"Saving model checkpoint to {output_dir}") 93 | if hasattr(self.model.pretrained_model, "peft_config"): # LoRA 94 | self.model.pretrained_model.save_pretrained(output_dir) # only save peft weights with the built-in method 95 | else: # Freeze and P-Tuning 96 | save_trainable_params(output_dir, self.model.pretrained_model) 97 | if hasattr(self.model, "v_head"): 98 | save_valuehead_params(output_dir, self.model.v_head) # save valuehead weights 99 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 100 | torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME)) 101 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import logging 6 | from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple 7 | 8 | from transformers import DataCollatorWithPadding, Seq2SeqTrainingArguments 9 | from transformers.trainer import TRAINING_ARGS_NAME, TRAINER_STATE_NAME 10 | from transformers.tokenization_utils import PreTrainedTokenizer 11 | 12 | from trl import PPOTrainer, AutoModelForCausalLMWithValueHead 13 | from trl.core import LengthSampler 14 | from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits 15 | 16 | from .config import FinetuningArguments 17 | 18 | from .other import ( 19 | AverageMeter, 20 | save_trainable_params, 21 | save_valuehead_params, 22 | FINETUNING_ARGS_NAME 23 | ) 24 | 25 | 26 | logger = logging.getLogger(__name__) # setup logging 27 | logger.setLevel(logging.INFO) 28 | logging.basicConfig( 29 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 30 | datefmt="%m/%d/%Y %H:%M:%S", 31 | handlers=[logging.StreamHandler(sys.stdout)], 32 | ) 33 | 34 | 35 | def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: 36 | if target == "reward": 37 | valuehead_state_dict = model.v_head.state_dict() 38 | 39 | setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"]) 40 | setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"]) 41 | 42 | model.pretrained_model.set_adapter(target) 43 | model.v_head.load_state_dict({ 44 | "summary.weight": getattr(model, "{}_head_weight".format(target)), 45 | "summary.bias": getattr(model, "{}_head_bias".format(target)) 46 | }) 47 | 48 | 49 | @torch.no_grad() 50 | def compute_rewards( 51 | input_ids: torch.Tensor, # (batch size x seq len) with format `X [gMASK] [BOS] Y [EOS] [PAD] ... 
[PAD]` 52 | model: AutoModelForCausalLMWithValueHead, 53 | tokenizer: PreTrainedTokenizer 54 | ) -> torch.Tensor: 55 | 56 | replace_model(model, target="reward") 57 | 58 | _, _, values = model(input_ids=input_ids) 59 | values = values.transpose(0, 1) 60 | 61 | rewards = [] 62 | for i in range(input_ids.size(0)): 63 | eos_idx = (input_ids[i] == tokenizer.eos_token_id).nonzero() # Note: checking with eos_token is unsafe 64 | if len(eos_idx): 65 | eos_idx = eos_idx[0].item() 66 | else: 67 | eos_idx = input_ids.size(1) - 1 68 | rewards.append(values[i][eos_idx]) 69 | rewards = torch.stack(rewards, dim=0) 70 | 71 | replace_model(model, target="default") 72 | 73 | return rewards 74 | 75 | 76 | def cast_layernorm_dtype( 77 | model: AutoModelForCausalLMWithValueHead, 78 | layer_norm_names: List[str] = ["layernorm"], # for chatglm setting 79 | layer_norm_params: Optional[Dict[str, torch.Tensor]] = None 80 | ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: 81 | 82 | layer_norm_state_dict = {} 83 | 84 | for name, param in model.named_parameters(): 85 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 86 | if layer_norm_params is not None: 87 | param.data = layer_norm_params[name] # restore float32 weights 88 | else: 89 | layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability 90 | param.data = param.data.to(torch.float16) 91 | 92 | return model, layer_norm_state_dict 93 | 94 | 95 | class PPODataCollatorForChatGLM(DataCollatorWithPadding): 96 | r""" 97 | Data collator for ChatGLM. It is capable of dynamically padding for batched data. 98 | 99 | Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/65512697dc67779a6e53c267488aba0ec4d7c02a/train.py#L156 100 | """ 101 | def __init__( 102 | self, 103 | tokenizer: PreTrainedTokenizer, 104 | min_input_length: int, 105 | max_input_length: int, 106 | inference_mode: bool = False, 107 | ): 108 | super().__init__(tokenizer, padding=True) 109 | self.inference_mode = inference_mode 110 | if min_input_length < max_input_length: 111 | self.input_size = LengthSampler(min_input_length, max_input_length) 112 | else: 113 | self.input_size = lambda: max_input_length # always use max_input_length 114 | 115 | def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]: 116 | r""" 117 | Pads batched data to the longest sequence in the batch. We adopt left-padding for ppo data. 118 | 119 | Equips with a length sampler to generate sequences with variable lengths. 120 | 121 | ChatGLM is able to generate attentions masks and position ids by itself. 122 | """ 123 | if self.inference_mode: 124 | raise NotImplementedError 125 | input_ids = [torch.tensor(feature["input_ids"][:self.input_size()]).flip(0) for feature in features] 126 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 127 | features = {"input_ids": input_ids.flip(-1)} 128 | return features 129 | 130 | class PPOTrainerForChatGLM(PPOTrainer): 131 | r""" 132 | Inherits PPOTrainer. 
133 | """ 134 | 135 | def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, *args, **kwargs): 136 | super().__init__(*args, **kwargs) 137 | self.steps = 0 138 | self.loss_meter = AverageMeter() 139 | self.reward_meter = AverageMeter() 140 | self.trainer_state = {"log_history": []} 141 | self.training_args = training_args 142 | self.finetuning_args = finetuning_args 143 | 144 | def generate( 145 | self, 146 | query_tensor: torch.Tensor, # (batch size x seq len) 147 | length_sampler: Callable = None, 148 | return_prompt: bool = True, 149 | **generation_kwargs, 150 | ) -> torch.Tensor: 151 | r""" 152 | Generate response with the model given the query tensor. 153 | 154 | Inspired by: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/trl/trainer/ppo_trainer.py#L387 155 | """ 156 | 157 | self.model, layer_norm_params = cast_layernorm_dtype(self.model) 158 | 159 | if length_sampler is not None: 160 | generation_kwargs["max_new_tokens"] = length_sampler() 161 | 162 | response = self.accelerator.unwrap_model(self.model).generate( 163 | input_ids=query_tensor, **generation_kwargs 164 | ) 165 | 166 | # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop 167 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273 168 | if self.model.pretrained_model.generation_config._from_model_config: 169 | self.model.pretrained_model.generation_config._from_model_config = False 170 | 171 | self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params) 172 | 173 | if not return_prompt and not self.is_encoder_decoder: 174 | return response[:, query_tensor.size(1):] 175 | return response 176 | 177 | def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): 178 | input_ids = [] 179 | for query, response in zip(queries, responses): # query is left-padded, response is right-padded 180 | start = (query != self.tokenizer.pad_token_id).nonzero()[0].item() 181 | input_ids.append(torch.cat((query[start:], response, query[:start]))) # change to right-padding 182 | 183 | input_data = self.data_collator([{"input_ids": ids} for ids in input_ids]).to(self.current_device) 184 | input_data.pop("labels", None) # we don't want to compute LM losses 185 | 186 | return input_data 187 | 188 | @PPODecorators.empty_cuda_cache() 189 | def batched_forward_pass( 190 | self, 191 | model: AutoModelForCausalLMWithValueHead, 192 | queries: torch.Tensor, 193 | responses: torch.Tensor, 194 | model_inputs: dict, 195 | ): 196 | r""" 197 | Calculate model outputs in multiple batches. 198 | 199 | Override to inject custom behavior. 
200 | """ 201 | bs = len(queries) 202 | fbs = self.config.mini_batch_size 203 | all_logprobs = [] 204 | all_logits = [] 205 | all_masks = [] 206 | all_values = [] 207 | 208 | for i in range(int(bs / fbs)): 209 | input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} 210 | 211 | input_ids = input_kwargs["input_ids"] 212 | logits, _, values = model(input_ids=input_ids) # chatglm only needs input_ids 213 | logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) 214 | 215 | values = values.transpose(0, 1) 216 | masks = torch.zeros_like(input_ids) 217 | 218 | for j in range(fbs): 219 | start = (input_ids[j] == self.tokenizer.bos_token_id).nonzero()[0].item() 220 | end = (input_ids[j] == self.tokenizer.eos_token_id).nonzero() 221 | if len(end): 222 | end = end[0].item() 223 | else: 224 | end = masks.size(1) 225 | masks[j][start:end] = 1 226 | 227 | all_logits.append(logits) 228 | all_values.append(values) 229 | all_logprobs.append(logprobs) 230 | all_masks.append(masks) 231 | 232 | return ( 233 | torch.cat(all_logprobs), 234 | torch.cat(all_logits)[:, :-1], 235 | torch.cat(all_values)[:, :-1], 236 | torch.cat(all_masks)[:, :-1], 237 | ) 238 | 239 | def update_stats(self, stats: Dict[str, Any], batch: Dict[str, torch.Tensor], rewards: torch.Tensor) -> None: 240 | self.steps += 1 241 | self.loss_meter.update(stats["ppo/loss/total"]) 242 | self.reward_meter.update(rewards.sum().item(), n=rewards.size(0)) 243 | if self.steps % self.training_args.logging_steps == 0: 244 | print("{{'loss': {:.4f}, 'reward': {:.4f}, 'learning_rate': {:}}}".format( 245 | self.loss_meter.avg, self.reward_meter.avg, stats["ppo/learning_rate"] 246 | )) 247 | self.trainer_state["log_history"].append({ 248 | "loss": self.loss_meter.avg, 249 | "reward": self.reward_meter.avg, 250 | "step": self.steps 251 | }) 252 | self.loss_meter.reset() 253 | self.reward_meter.reset() 254 | 255 | def save_state(self, output_dir: Optional[str] = None) -> None: 256 | r""" 257 | Saves trainer state. 258 | """ 259 | output_dir = output_dir if output_dir is not None else self.training_args.output_dir 260 | os.makedirs(output_dir, exist_ok=True) 261 | json.dump(self.trainer_state, open(os.path.join(output_dir, TRAINER_STATE_NAME), "w", encoding="utf-8", newline="\n")) 262 | 263 | def save_model(self, output_dir: Optional[str] = None) -> None: 264 | r""" 265 | Saves trainable parameters as model checkpoints. We use `self.model.pretrained_model` to refer to the backbone model. 266 | 267 | Override to inject custom behavior. 
268 | """ 269 | output_dir = output_dir if output_dir is not None else self.training_args.output_dir 270 | os.makedirs(output_dir, exist_ok=True) 271 | logger.info(f"Saving model checkpoint to {output_dir}") 272 | if hasattr(self.model.pretrained_model, "peft_config"): # LoRA 273 | self.model.pretrained_model.save_pretrained(output_dir) # only save peft weights with the built-in method 274 | else: # Freeze and P-Tuning 275 | save_trainable_params(output_dir, self.model.pretrained_model) 276 | if hasattr(self.model, "v_head"): 277 | save_valuehead_params(output_dir, self.model.v_head) # save valuehead weights 278 | torch.save(self.training_args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 279 | torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME)) 280 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/utils/seq2seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import logging 6 | import numpy as np 7 | from dataclasses import dataclass 8 | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union 9 | 10 | from transformers import Seq2SeqTrainer, DataCollatorForSeq2Seq 11 | from transformers.trainer import PredictionOutput, TRAINING_ARGS_NAME 12 | from transformers.deepspeed import is_deepspeed_zero3_enabled 13 | from transformers.modeling_utils import PreTrainedModel 14 | from transformers.tokenization_utils import PreTrainedTokenizer 15 | 16 | import jieba 17 | from rouge_chinese import Rouge 18 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 19 | 20 | from .config import FinetuningArguments 21 | 22 | from .other import ( 23 | save_trainable_params, 24 | IGNORE_INDEX, 25 | FINETUNING_ARGS_NAME, 26 | PREDICTION_FILE_NAME 27 | ) 28 | 29 | 30 | logger = logging.getLogger(__name__) # setup logging 31 | logger.setLevel(logging.INFO) 32 | logging.basicConfig( 33 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 34 | datefmt="%m/%d/%Y %H:%M:%S", 35 | handlers=[logging.StreamHandler(sys.stdout)], 36 | ) 37 | 38 | 39 | # Note: The ChatGLM tokenizer assigns False on token to be attended in attention mask. In general settings, it should be True. 40 | # Refer to: https://huggingface.co/THUDM/chatglm-6b/blob/6650ae3a53c28fc176d06762ca80b05d5ab3792b/tokenization_chatglm.py#L401 41 | class Seq2SeqDataCollatorForChatGLM(DataCollatorForSeq2Seq): 42 | r""" 43 | Data collator for ChatGLM. It is capable of dynamically padding for batched data. 44 | 45 | Inspired by: https://github.com/tatsu-lab/stanford_alpaca/blob/65512697dc67779a6e53c267488aba0ec4d7c02a/train.py#L156 46 | """ 47 | def __init__( 48 | self, 49 | tokenizer: PreTrainedTokenizer, 50 | model: PreTrainedModel, 51 | ignore_pad_token_for_loss: bool, 52 | inference_mode: bool = False 53 | ): 54 | label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id 55 | super().__init__(tokenizer, model=model, label_pad_token_id=label_pad_token_id, padding=True) 56 | self.label_pad_token_id = label_pad_token_id 57 | self.inference_mode = inference_mode 58 | 59 | def __call__(self, features: Sequence[Dict[str, Sequence]]) -> Dict[str, torch.Tensor]: 60 | r""" 61 | Pads batched data to the longest sequence in the batch. 62 | 63 | ChatGLM is able to generate attentions masks and position ids by itself. 
64 | """ 65 | if self.inference_mode: # evaluation set adopts left-padding while training set adopts right-padding 66 | return super().__call__(features) 67 | input_ids, labels = [[torch.tensor(feature[key]) for feature in features] for key in ("input_ids", "labels")] 68 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 69 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=self.label_pad_token_id) 70 | features = {"input_ids": input_ids, "labels": labels} 71 | return features 72 | 73 | 74 | @dataclass 75 | class ComputeMetrics: 76 | r""" 77 | Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForChatGLM. 78 | 79 | Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307 80 | """ 81 | 82 | tokenizer: PreTrainedTokenizer 83 | 84 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 85 | r""" 86 | Uses the model predictions to compute metrics. 87 | """ 88 | preds, labels = eval_preds 89 | if isinstance(preds, tuple): 90 | preds = preds[0] 91 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 92 | # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True. 93 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) 94 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 95 | 96 | score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 97 | for pred, label in zip(decoded_preds, decoded_labels): 98 | hypothesis = list(jieba.cut(pred)) 99 | reference = list(jieba.cut(label)) 100 | 101 | if len(" ".join(hypothesis).split()) == 0: 102 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} 103 | else: 104 | rouge = Rouge() 105 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) 106 | result = scores[0] 107 | 108 | for k, v in result.items(): 109 | score_dict[k].append(round(v["f"] * 100, 4)) 110 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 111 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 112 | 113 | return {k: float(np.mean(v)) for k, v in score_dict.items()} 114 | 115 | 116 | class Seq2SeqTrainerForChatGLM(Seq2SeqTrainer): 117 | r""" 118 | Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. 119 | """ 120 | 121 | def __init__(self, finetuning_args: FinetuningArguments, *args, **kwargs): 122 | super().__init__(*args, **kwargs) 123 | self.finetuning_args = finetuning_args 124 | 125 | def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: 126 | r""" 127 | Saves trainable parameters as model checkpoints. 128 | 129 | Override to inject custom behavior. 
130 | """ 131 | output_dir = output_dir if output_dir is not None else self.args.output_dir 132 | os.makedirs(output_dir, exist_ok=True) 133 | logger.info(f"Saving model checkpoint to {output_dir}") 134 | if hasattr(self.model, "peft_config"): # LoRA 135 | self.model.save_pretrained(output_dir) # only save peft weights with the built-in method 136 | else: 137 | save_trainable_params(output_dir, self.model) # Freeze and P-Tuning 138 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 139 | torch.save(self.finetuning_args, os.path.join(output_dir, FINETUNING_ARGS_NAME)) 140 | 141 | def prediction_step( 142 | self, 143 | model: torch.nn.Module, 144 | inputs: Dict[str, Union[torch.Tensor, Any]], 145 | prediction_loss_only: bool, 146 | ignore_keys: Optional[List[str]] = None 147 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 148 | r""" 149 | Performs an evaluation step on `model` using `inputs` for ChatGLM. 150 | 151 | Override to inject custom behavior. It is not directly used by external scripts. 152 | """ 153 | # Override to inject custom bevavior. 154 | if not self.args.predict_with_generate or prediction_loss_only: 155 | return super().prediction_step( 156 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 157 | ) 158 | 159 | has_labels = "labels" in inputs 160 | inputs = self._prepare_inputs(inputs) 161 | 162 | gen_kwargs = self._gen_kwargs.copy() 163 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: 164 | gen_kwargs["max_length"] = self.model.config.max_length 165 | gen_kwargs["num_beams"] = gen_kwargs["num_beams"] \ 166 | if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 167 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 168 | gen_kwargs["synced_gpus"] = gen_kwargs["synced_gpus"] \ 169 | if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 170 | 171 | if "attention_mask" in inputs: 172 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 173 | if "position_ids" in inputs: 174 | gen_kwargs["position_ids"] = inputs.get("position_ids", None) 175 | if "global_attention_mask" in inputs: 176 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 177 | 178 | # prepare generation inputs 179 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 180 | generation_inputs = inputs[self.model.encoder.main_input_name] 181 | else: 182 | generation_inputs = inputs[self.model.main_input_name] 183 | 184 | gen_kwargs["input_ids"] = generation_inputs 185 | generated_tokens = self.model.generate(**gen_kwargs) 186 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] # important for ChatGLM 187 | 188 | # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop 189 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273 190 | if self.model.generation_config._from_model_config: 191 | self.model.generation_config._from_model_config = False 192 | 193 | # Retrieves GenerationConfig from model.generation_config 194 | gen_config = self.model.generation_config 195 | # in case the batch is shorter than max length, the output should be padded 196 | if generated_tokens.shape[-1] < gen_config.max_length: 197 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) 198 | elif 
gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: 199 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) 200 | 201 | loss = None 202 | 203 | if self.args.prediction_loss_only: 204 | return loss, None, None 205 | 206 | if has_labels: 207 | labels = inputs["labels"] 208 | if labels.shape[-1] < gen_config.max_length: 209 | labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) 210 | elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: 211 | labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) 212 | else: 213 | labels = None 214 | 215 | return loss, generated_tokens, labels 216 | 217 | def save_predictions( 218 | self, 219 | predict_results: PredictionOutput, 220 | tokenizer: PreTrainedTokenizer 221 | ) -> None: 222 | r""" 223 | Saves model predictions to `output_dir`. 224 | 225 | A custom behavior that not contained in Seq2SeqTrainer. 226 | """ 227 | if self.is_world_process_zero(): 228 | if self.args.predict_with_generate: 229 | predictions = tokenizer.batch_decode(predict_results.predictions, skip_special_tokens=True) 230 | predictions = [pred.strip() for pred in predictions] 231 | labels = tokenizer.batch_decode(predict_results.label_ids, skip_special_tokens=True) 232 | labels = [label.strip() for label in labels] 233 | output_prediction_file = os.path.join(self.args.output_dir, PREDICTION_FILE_NAME) 234 | logger.info(f"Saving prediction results to {output_prediction_file}") 235 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 236 | res = [] 237 | for pred, label in zip(predictions, labels): 238 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) 239 | writer.write("\n".join(res)) 240 | -------------------------------------------------------------------------------- /MedQA-ChatGLM/web_demo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implement user interface in browser for ChatGLM fine-tuned with PEFT. 3 | # This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/web_demo.py 4 | 5 | 6 | import gradio as gr 7 | import mdtex2html 8 | 9 | from utils import ModelArguments, load_pretrained 10 | from transformers import HfArgumentParser 11 | 12 | 13 | parser = HfArgumentParser(ModelArguments) 14 | model_args, = parser.parse_args_into_dataclasses() 15 | model, tokenizer = load_pretrained(model_args) 16 | model = model.cuda() 17 | model.eval() 18 | 19 | 20 | """Override Chatbot.postprocess""" 21 | 22 | def postprocess(self, y): 23 | if y is None: 24 | return [] 25 | for i, (message, response) in enumerate(y): 26 | y[i] = ( 27 | None if message is None else mdtex2html.convert((message)), 28 | None if response is None else mdtex2html.convert(response), 29 | ) 30 | return y 31 | 32 | 33 | gr.Chatbot.postprocess = postprocess 34 | 35 | 36 | def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT 37 | lines = text.split("\n") 38 | lines = [line for line in lines if line != ""] 39 | count = 0 40 | for i, line in enumerate(lines): 41 | if "```" in line: 42 | count += 1 43 | items = line.split('`') 44 | if count % 2 == 1: 45 | lines[i] = f'
<pre><code class="language-{items[-1]}">'
46 | else:
47 |                 lines[i] = f'<br></code></pre>'
48 | else:
49 | if i > 0:
50 | if count % 2 == 1:
51 | line = line.replace("`", "\`")
52 |                     line = line.replace("<", "&lt;")
53 |                     line = line.replace(">", "&gt;")
54 |                     line = line.replace(" ", "&nbsp;")
55 |                     line = line.replace("*", "&ast;")
56 |                     line = line.replace("_", "&lowbar;")
57 |                     line = line.replace("-", "&#45;")
58 |                     line = line.replace(".", "&#46;")
59 |                     line = line.replace("!", "&#33;")
60 |                     line = line.replace("(", "&#40;")
61 |                     line = line.replace(")", "&#41;")
62 |                     line = line.replace("$", "&#36;")
63 |                 lines[i] = "<br>"+line
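The reward-modeling and PPO utilities above rely on two compact tricks that are easiest to see on toy tensors: the pairwise loss computed in PairwiseTrainerForChatGLM.compute_loss, loss = -log(sigmoid(s_accept - s_reject)), and the flip-based left-padding used in PPODataCollatorForChatGLM.__call__. The standalone sketch below is not part of the repository; the scores and token ids are made-up values chosen only to show the behaviour.

import torch

# Pairwise reward loss: small when the preferred answer scores higher
# than the rejected one, large otherwise.
s_accept = torch.tensor([1.2, 0.3, 2.0])   # toy value-head outputs at EOS, preferred answers
s_reject = torch.tensor([0.4, 0.9, -1.0])  # toy value-head outputs at EOS, rejected answers
loss = -torch.log(torch.sigmoid(s_accept - s_reject)).mean()
print(round(loss.item(), 4))

# Left-padding via flip: reverse each sequence, right-pad the batch,
# then reverse along the sequence dimension again.
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
flipped = [s.flip(0) for s in seqs]
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=0)
left_padded = padded.flip(-1)
print(left_padded)  # tensor([[5, 6, 7], [0, 8, 9]])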