├── README.md
├── data
│   ├── ablation_data
│   │   ├── em_es_ia_sr_re
│   │   │   ├── empathetic_dialogue_valid.json
│   │   │   ├── test
│   │   │   │   └── empathetic_dialogue_test.json
│   │   │   └── train
│   │   │       └── empathetic_dialogue_train.json
│   │   ├── em_sr_re
│   │   │   ├── empathetic_dialogue_valid.json
│   │   │   ├── test
│   │   │   │   └── empathetic_dialogue_test.json
│   │   │   └── train
│   │   │       └── empathetic_dialogue_train.json
│   │   ├── re
│   │   │   ├── empathetic_dialogue_valid.json
│   │   │   ├── test
│   │   │   │   └── empathetic_dialogue_test.json
│   │   │   └── train
│   │   │       └── empathetic_dialogue_train.json
│   │   └── sr_re
│   │       ├── empathetic_dialogue_valid.json
│   │       ├── test
│   │       │   └── empathetic_dialogue_test.json
│   │       └── train
│   │           └── empathetic_dialogue_train.json
│   ├── test.json
│   ├── train.json
│   └── val.json
├── merge.py
├── requirements.txt
├── scripts
│   ├── supervised_finetune_llama2_cot.sh
│   ├── supervised_finetune_llama2_cot_ablation.sh
│   ├── test_llama2_chat_sft_cot.sh
│   └── test_llama2_inference_cot.sh
├── supervised_finetuning_cot.py
├── test_llama2_chat_cot.py
└── test_llama2_inference_cot.py

/README.md:
--------------------------------------------------------------------------------
# ESCoT: Towards Interpretable Emotional Support Dialogue Systems

This is the repository of our ACL 2024 main paper "[**ESCoT: Towards Interpretable Emotional Support Dialogue Systems**](https://aclanthology.org/2024.acl-long.723/)".

## ESD-CoT Dataset

Our ESD-CoT dataset is organized under the `data` folder and is split into three JSON files: `train.json`, `val.json`, and `test.json`. Each file contains samples structured as follows:

```json
{
    "id": <id>,
    "original_data": {
        "dialog": [
            {
                "speaker": "seeker",
                "content": "Hi, I'm having a really hard time managing my schoolwork and extracurricular activities. I feel like there's just not enough hours in the day."
            },
            ...
            {
                "speaker": "seeker",
                "content": "Yeah, I can try that."
            }
        ],
        "strategy": "Providing Suggestions",
        "response": "Great, and let's touch base next week to see if the list has been helpful. In the meantime, have you considered talking to your teacher or a guidance counselor about feeling overwhelmed?"
    },
    "cot_data": {
        "emotion": "The seeker feels overwhelmed and stretched thin.",
        "emotion_stimuli": "The seeker is struggling to manage schoolwork...",
        "individual_appraisal": "The seeker thinks they are not able to do anything well...",
        "recognized_strategy": "Providing Suggestions",
        "strategy_reason": "To address the seeker's feeling of being overwhelmed and..."
    }
}
```
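
For quick inspection, each split can be loaded with the standard `json` module. Below is a minimal sketch, assuming each file is a JSON list of samples in the format above:

```python
import json

# Load the training split of ESD-CoT (assumed to be a list of samples).
with open("data/train.json", "r", encoding="utf-8") as f:
    train_samples = json.load(f)

sample = train_samples[0]
print("Dialogue turns:", len(sample["original_data"]["dialog"]))
print("Strategy:", sample["original_data"]["strategy"])

# Chain-of-thought annotation: emotion, stimuli, appraisal, strategy, reason.
for key, value in sample["cot_data"].items():
    print(f"{key}: {value}")
```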
Additionally, we provide the training data in instruction format in the `data/ablation_data` folder.

## Model Training

### Download the pretrained models
Download the [**LLAMA2-7B-CHAT**](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) model.

The training of the LLAMA2-CHAT model is based on the [**SFT trainer of Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl).

### Train Model
Run `bash scripts/supervised_finetune_llama2_cot.sh` to train your model.

Run `bash scripts/supervised_finetune_llama2_cot_ablation.sh` to train the ablation study models.

### Test Model
Run `bash scripts/test_llama2_chat_sft_cot.sh` for interactive chatting, or `bash scripts/test_llama2_inference_cot.sh` for batch inference on a JSON file.

## Cite
If you use our code or your research is related to our work, please kindly cite our paper:

```bib
@inproceedings{zhang-etal-2024-escot,
    title = "{ESC}o{T}: Towards Interpretable Emotional Support Dialogue Systems",
    author = "Zhang, Tenggan and Zhang, Xinjie and Zhao, Jinming and Zhou, Li and Jin, Qin",
    booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    year = "2024"
}
```

Please contact zhangxinjie827@ruc.edu.cn if you have any problems.

--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
import os

import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer
)

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


def merge_llm_with_lora(base_model_name, adapter_model_name, output_name, push_to_hub=False):
    # Load the base model in bfloat16 and fold the LoRA adapter weights into it.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        return_dict=True,
        torch_dtype=torch.bfloat16
    )

    model = PeftModel.from_pretrained(base_model, adapter_model_name)
    model = model.merge_and_unload()

    if "decapoda" in base_model_name.lower():
        # The decapoda LLaMA checkpoints ship without special tokens, so add them explicitly.
        tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
                "pad_token": DEFAULT_PAD_TOKEN,
            }
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)

    if push_to_hub:
        print("Saving to hub ...")
        model.push_to_hub(f"{base_model_name}-merged", use_temp_dir=False, private=True)
        tokenizer.push_to_hub(f"{base_model_name}-merged", use_temp_dir=False, private=True)
    else:
        output_name = os.path.join(output_name, "final_checkpoint-merged")
        model.save_pretrained(output_name)
        tokenizer.save_pretrained(output_name)
        print(f"Model saved to {output_name}")
--------------------------------------------------------------------------------
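
`merge.py` above only defines `merge_llm_with_lora`; `supervised_finetuning_cot.py` calls it when `--merge_lora` is passed, but it can also be invoked manually. A minimal, hypothetical driver sketch (the paths below are placeholders for your own base model and trained adapter, not files shipped with this repository):

```python
# merge_adapter.py -- hypothetical helper, not part of the repository.
from merge import merge_llm_with_lora

if __name__ == "__main__":
    merge_llm_with_lora(
        base_model_name="./pretrained_models/Llama-2-7b-chat-hf",  # placeholder: base LLaMA2-Chat weights
        adapter_model_name="./checkpoints/cot/supervised_llama2_cot/final_checkpoint",  # placeholder: LoRA adapter dir
        output_name="./checkpoints/cot/supervised_llama2_cot",  # merged model goes to <output_name>/final_checkpoint-merged
    )
```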
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.4.0
datasets>=1.17.0
transformers>=4.28.0
accelerate
evaluate
# git+https://github.com/huggingface/peft.git
# git+https://github.com/lvwerra/trl.git
peft
trl
tqdm
sentencepiece
bitsandbytes
wandb
--------------------------------------------------------------------------------
/scripts/supervised_finetune_llama2_cot.sh:
--------------------------------------------------------------------------------
#!/bin/sh

export CUDA_VISIBLE_DEVICES=0,1,2,3

run_name='supervised_llama2_cot'

torchrun --nnodes 1 --nproc_per_node 4 supervised_finetuning_cot.py \
    --base_model '/datassd2/ztg/pretrained_models/Llama-2-7b-chat-hf' \
    --dataset_name './data/ablation_data/em_es_ia_sr_re' \
    --lr_scheduler_type 'cosine' \
    --learning_rate 1e-5 \
    --max_steps 10000 \
    --save_freq 500 \
    --seq_length 2048 \
    --batch_size 8 \
    --run_name $run_name \
    --output_dir './checkpoints/cot/'$run_name
--------------------------------------------------------------------------------
/scripts/supervised_finetune_llama2_cot_ablation.sh:
--------------------------------------------------------------------------------
#!/bin/sh

export CUDA_VISIBLE_DEVICES=0,1,2,3

base_model='./pretrained_models/Llama-2-7b-chat-hf'
dataset_base='./data/ablation_data/'
output_base='./checkpoints/cot/'
lr_scheduler_type='cosine'
learning_rate=5e-5
max_steps=380
save_freq=38
eval_freq=380
seq_length=2048
batch_size=8

settings=("em_es_ia_sr_re" "em_sr_re" "sr_re" "re")

for setting in "${settings[@]}"
do
    run_name="supervised_llama2_cot_ablation_${setting}"
    dataset_name="${dataset_base}${setting}"
    output_dir="${output_base}${run_name}"

    echo "Running setting: $setting"
    echo "Dataset path: $dataset_name"
    echo "Output path: $output_dir"

    torchrun --nnodes 1 --nproc_per_node 4 supervised_finetuning_cot.py \
        --base_model "$base_model" \
        --dataset_name "$dataset_name" \
        --lr_scheduler_type "$lr_scheduler_type" \
        --learning_rate "$learning_rate" \
        --max_steps "$max_steps" \
        --save_freq "$save_freq" \
        --eval_freq "$eval_freq" \
        --seq_length "$seq_length" \
        --batch_size "$batch_size" \
        --run_name "$run_name" \
        --output_dir "$output_dir" \
        --save_total_limit 100 \
        --seed 1104
done
--------------------------------------------------------------------------------
/scripts/test_llama2_chat_sft_cot.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Define the model path and GPU id as variables
MODEL_PATH="change_to_checkpoint_path"
GPU_ID=0

# Execute the Python script with the model path and GPU id as arguments
python test_llama2_chat_cot.py --model_path "$MODEL_PATH" --gpu_id "$GPU_ID"
--------------------------------------------------------------------------------
/scripts/test_llama2_inference_cot.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Define the model path, data path, and GPU id as variables
MODEL_PATH="change_to_checkpoint_path"
JSON_PATH="./data/ablation_data/em_es_ia_sr_re/empathetic_dialogue_valid.json"
GPU_ID=0

# Execute the Python script with the model path, data path, and GPU id as arguments
python test_llama2_inference_cot.py --model_path "$MODEL_PATH" --gpu_id "$GPU_ID" --json_path "$JSON_PATH"
--------------------------------------------------------------------------------
/supervised_finetuning_cot.py:
--------------------------------------------------------------------------------
import os
import argparse

from tqdm import tqdm
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
    TrainingArguments,
    logging,
    set_seed
)
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset

from merge import merge_llm_with_lora  # merge.py lives at the repository root


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, default="")
    parser.add_argument("--dataset_name", type=str, default="./data/alpaca_gpt4_data.json")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--size_valid_set", type=int, default=4000)
    parser.add_argument("--streaming", action="store_true", default=False)
    parser.add_argument("--shuffle_buffer", type=int, default=5000)

    parser.add_argument("--seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--eos_token_id", type=int, default=49152)

    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--lora_target_modules", type=str, default=None)

    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="linear")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--warmup_ratio", type=float, default=0.)

    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--no_bf16", action="store_false", default=True)
    parser.add_argument("--no_gradient_checkpointing", action="store_false", default=True)
    parser.add_argument("--seed", type=int, default=1103)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints/supervised_llama/")
    parser.add_argument("--log_freq", type=int, default=1)
    parser.add_argument("--eval_freq", type=int, default=1000)
    parser.add_argument("--save_freq", type=int, default=1000)
    parser.add_argument("--save_total_limit", type=int, default=3)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--run_name", type=str, default="llama-supervised-finetuned")
    parser.add_argument("--merge_lora", action="store_true", default=False)

    return parser.parse_args()


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    max_token_length = 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
            if len(tokenizer(text).tokens()) > max_token_length:
                max_token_length = len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))
            if len(tokenizer.tokenize(text)) > max_token_length:
                max_token_length = len(tokenizer.tokenize(text))

    print(f"max token length: {max_token_length}")
    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


# adapted to llama2
def prepare_sample_text(data_point):
    """Prepare the text from a sample of the dataset."""
    if data_point["input"]:
        return f"""Human:
{data_point["input"]}
{data_point["instruction"]}
Assistant:
{data_point["output"]}
"""
    else:
        return f"""Human:
{data_point["instruction"]}
Assistant:
{data_point["output"]}
"""


def create_datasets(tokenizer, args):
    train_json_path = os.path.join(args.dataset_name, "train/empathetic_dialogue_train.json")
    train_data = load_dataset("json", data_files=train_json_path, split="train")
    train_data = train_data.shuffle(seed=args.seed)

    val_json_path = os.path.join(args.dataset_name, "empathetic_dialogue_valid.json")
    valid_data = load_dataset("json", data_files=val_json_path, split="train")
    valid_data = valid_data.shuffle(seed=args.seed)

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )

    print(f"Size of the train dataset: {len(train_dataset)}")
    print(f"Size of the validation dataset: {len(valid_dataset)}")

    return train_dataset, valid_dataset


def run_training(args, train_data, val_data, tokenizer=None):
    print("Loading the model")

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )

    train_data.start_iteration = 0

    print("Starting main loop")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        evaluation_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        save_total_limit=args.save_total_limit,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.no_gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.no_bf16,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        run_name=args.run_name,
        report_to="wandb",
        ddp_find_unused_parameters=False if int(os.environ.get("WORLD_SIZE", 1)) != 1 else None,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        load_in_8bit=True,
    )

    if args.resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            args.resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                args.resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            args.resume_from_checkpoint = None

        if os.path.exists(checkpoint_name):
            import torch
            from peft import (
                get_peft_model,
                prepare_model_for_int8_training,
                set_peft_model_state_dict
            )
            print(f"Restarting from {checkpoint_name}")
            model = prepare_model_for_int8_training(model)
            model = get_peft_model(model, lora_config)

            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        peft_config=lora_config,
        max_seq_length=args.seq_length,
        packing=True,
    )

    print_trainable_parameters(model)

    print("Training...")
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    print("Saving last checkpoint of the model")
    final_model_path = os.path.join(args.output_dir, "final_checkpoint/")
    trainer.model.save_pretrained(final_model_path)

    if args.merge_lora:
        merge_llm_with_lora(args.base_model, final_model_path, args.output_dir)


def main(args):
    if "llama" in args.base_model.lower():
        tokenizer = LlamaTokenizer.from_pretrained(args.base_model)
        # LLaMA special tokens; the pad token falls back to <unk> (id 0), matching
        # the pad_token_id=0 used by the inference scripts.
        tokenizer.add_special_tokens(
            {
                "eos_token": "</s>",
                "bos_token": "<s>",
                "unk_token": "<unk>",
                "pad_token": "<unk>",
            }
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=False)
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.pad_token = tokenizer.eos_token

    train_dataset, eval_dataset = create_datasets(tokenizer, args)
    run_training(args, train_dataset, eval_dataset, tokenizer)


if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

    args = get_args()
    assert args.base_model != "", "Please provide the llama model path"

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)
--------------------------------------------------------------------------------
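
Note: `create_datasets` above loads the `empathetic_dialogue_*.json` files under `data/ablation_data/<setting>/`, whose records must provide `instruction`, `input`, and `output` fields (the inference script below additionally reads `dialog_id`). The sketch below is a rough illustration of such a record and of how `prepare_sample_text` renders it; the `instruction` text is the prompt used by the test scripts, while the `input`/`output` values are invented placeholders, not actual ESD-CoT annotations:

```python
# Illustrative only: field names come from prepare_sample_text and
# test_llama2_inference_cot.py; the field values are invented placeholders.
# Run from the repository root so supervised_finetuning_cot.py is importable.
from supervised_finetuning_cot import prepare_sample_text

record = {
    "instruction": "Generate the response using the pipeline of emotion, emotion stimulus, "
                   "individual appraisal, strategy reason and response.",
    "input": "seeker: Hi, I'm having a really hard time managing my schoolwork...",
    "output": "Emotion: The seeker feels overwhelmed... Response: Great, and let's touch base next week...",
    "dialog_id": 0,
}

# Renders to:
# Human:
# <input>
# <instruction>
# Assistant:
# <output>
print(prepare_sample_text(record))
```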
/test_llama2_chat_cot.py:
--------------------------------------------------------------------------------
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
import sys
import argparse
import torch


def main(args):
    prompt = "Generate the response using the pipeline of emotion, emotion stimulus, individual appraisal, strategy reason and response."

    device = torch.device(f"cuda:{args.gpu_id}") if torch.cuda.is_available() else torch.device("cpu")
    model = LlamaForCausalLM.from_pretrained(args.model_path, device_map=device, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    print("Human:")
    line = input()
    inputs = ""
    while line:
        # Build the prompt: dialogue history (if any), the current seeker turn,
        # and the chain-of-thought instruction.
        if inputs == "":
            inputs = 'Human: ' + line.strip() + '\n' + prompt + '\nAssistant:'
        else:
            inputs = inputs.replace("Human: ", "")
            inputs = inputs.replace('\n' + prompt, "")
            inputs = inputs + '\nseeker: ' + line.strip()
            inputs = 'Human: ' + inputs + '\n' + prompt + '\nAssistant:'

        input_ids = tokenizer(inputs, return_tensors="pt").input_ids
        input_ids = input_ids.to(device)
        outputs = model.generate(input_ids, max_new_tokens=500, do_sample=True, top_k=30, top_p=0.85, temperature=0.5, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
        rets = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        response = rets[0].strip().replace(inputs, "")

        # Append the generated supporter turn to the history and read the next seeker turn.
        inputs = inputs.replace("\nAssistant:", '\nsupporter:' + response.strip())
        print("\nAssistant:" + response)
        print("\n------------------------------------------------\nSeeker:")
        line = input()

        if line == "clear":
            inputs = ""
            print("History cleared.")
            print("\n------------------------------------------------\nHuman:")
            line = input()
        if line == "exit":
            sys.exit()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="The path of the pretrained model.")
    parser.add_argument("--gpu_id", type=int, required=True, help="The id of the GPU to be used.")
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/test_llama2_inference_cot.py:
--------------------------------------------------------------------------------
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
import torch
from tqdm import tqdm
import os


# Function to prepare data
def prepare_sample_text(data_point):
    """Prepare the text from a sample of the dataset."""
    return f"""Human:
{data_point["input"]}
{data_point["instruction"]}
Assistant: """


# Main function
def main(args):
    # Load model and tokenizer
    device = torch.device(f"cuda:{args.gpu_id}") if torch.cuda.is_available() else torch.device("cpu")
    model = LlamaForCausalLM.from_pretrained(args.model_path, device_map=device, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # Read JSON file
    with open(args.json_path, 'r') as file:
        data = json.load(file)

    # Store inference results
    results = []

    # Perform inference for each data point
    for data_point in tqdm(data):
        prepared_text = prepare_sample_text(data_point)
        input_ids = tokenizer(prepared_text, return_tensors="pt").input_ids
        prepared_text = tokenizer.decode(input_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        input_ids = input_ids.to(device)
        # Execute model inference
        outputs = model.generate(input_ids, max_new_tokens=500, do_sample=True, top_k=30, top_p=0.85, temperature=0.5, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
        rets = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        # Extract response and remove original text
        full_response = rets[0].strip()
        response = full_response.replace(prepared_text, "").strip()

        # Save results
        results.append({'input': data_point['input'], 'label': data_point['output'], 'prediction': response, 'dialog_id': data_point['dialog_id']})

    # Save results to file
    if "test" in args.json_path:
        with open(os.path.join(args.model_path, 'test_inference_results.json'), 'w') as outfile:
            json.dump(results, outfile, indent=4, ensure_ascii=False)
    else:
        with open(os.path.join(args.model_path, 'val_inference_results.json'), 'w') as outfile:
            json.dump(results, outfile, indent=4, ensure_ascii=False)


# Use argparse to handle command line arguments
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="The path of the pretrained model.")
    parser.add_argument("--gpu_id", type=int, required=True, help="The id of the GPU to be used.")
    parser.add_argument("--json_path", type=str, required=True, help="Path to the JSON file containing the data.")

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
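
The inference script above writes its predictions to `val_inference_results.json` or `test_inference_results.json` inside the model directory. A minimal sketch for loading them afterwards, e.g. to feed your own evaluation (the checkpoint path is a placeholder):

```python
import json
import os

model_path = "./checkpoints/cot/supervised_llama2_cot/final_checkpoint-merged"  # placeholder path

# Each entry holds the original input, the reference output ("label"), the model
# prediction, and the dialog id, as written by test_llama2_inference_cot.py.
with open(os.path.join(model_path, "test_inference_results.json"), "r", encoding="utf-8") as f:
    results = json.load(f)

print(len(results), "predictions")
print(results[0]["dialog_id"], results[0]["prediction"][:200])
```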