├── README.md
├── data
│   ├── ablation_data
│   │   ├── em_es_ia_sr_re
│   │   │   ├── empathetic_dialogue_valid.json
│   │   │   ├── test
│   │   │   │   └── empathetic_dialogue_test.json
│   │   │   └── train
│   │   │       └── empathetic_dialogue_train.json
│   │   ├── em_sr_re
│   │   │   ├── empathetic_dialogue_valid.json
│   │   │   ├── test
│   │   │   │   └── empathetic_dialogue_test.json
│   │   │   └── train
│   │   │       └── empathetic_dialogue_train.json
│   │   ├── re
│   │   │   ├── empathetic_dialogue_valid.json
│   │   │   ├── test
│   │   │   │   └── empathetic_dialogue_test.json
│   │   │   └── train
│   │   │       └── empathetic_dialogue_train.json
│   │   └── sr_re
│   │       ├── empathetic_dialogue_valid.json
│   │       ├── test
│   │       │   └── empathetic_dialogue_test.json
│   │       └── train
│   │           └── empathetic_dialogue_train.json
│   ├── test.json
│   ├── train.json
│   └── val.json
├── merge.py
├── requirements.txt
├── scripts
│   ├── supervised_finetune_llama2_cot.sh
│   ├── supervised_finetune_llama2_cot_ablation.sh
│   ├── test_llama2_chat_sft_cot.sh
│   └── test_llama2_inference_cot.sh
├── supervised_finetuning_cot.py
├── test_llama2_chat_cot.py
└── test_llama2_inference_cot.py
/README.md:
--------------------------------------------------------------------------------
1 | # ESCoT: Towards Interpretable Emotional Support Dialogue Systems
2 |
3 |
4 |
5 | This is the repository for our ACL 2024 main conference paper "[**ESCoT: Towards Interpretable Emotional Support Dialogue Systems**](https://aclanthology.org/2024.acl-long.723/)".
6 |
7 | ## ESD-CoT Dataset
8 |
9 | Our ESD-CoT dataset is organized under the `data` folder and is split into three JSON files: `train.json`, `val.json`, and `test.json`. Each file contains samples structured as follows:
10 |
11 | ```json
12 | {
13 | "id": ,
14 | "original_data": {
15 | "dialog": [
16 | {
17 | "speaker": "seeker",
18 | "content": "Hi, I'm having a really hard time managing my schoolwork and extracurricular activities. I feel like there's just not enough hours in the day."
19 | },
20 | ...
21 | {
22 | "speaker": "seeker",
23 | "content": "Yeah, I can try that."
24 | }
25 | ],
26 | "strategy": "Providing Suggestions",
27 | "response": "Great, and let's touch base next week to see if the list has been helpful. In the meantime, have you considered talking to your teacher or a guidance counselor about feeling overwhelmed?"
28 | },
29 | "cot_data": {
30 | "emotion": "The seeker feels overwhelmed and stretched thin.",
31 | "emotion_stimuli": "The seeker is struggling to manage schoolwork...",
32 | "individual_appraisal": "The seeker thinks they are not able to do anything well...",
33 | "recognized_strategy": "Providing Suggestions",
34 | "strategy_reason": "To address the seeker's feeling of being overwhelmed and..."
35 | }
36 | }
37 | ```
38 | Additionally, we provide instruction-format training data in the `data/ablation_data` folder; the sub-folder names (`em_es_ia_sr_re`, `em_sr_re`, `sr_re`, `re`) indicate which chain components (emotion, emotion stimuli, individual appraisal, strategy reason, response) the training targets contain.
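For a quick sanity check, the splits can be loaded with the standard `json` module. The snippet below is a minimal sketch that assumes each split is stored as a JSON list of samples in the structure shown above:

```python
import json

# Load the training split (path relative to the repository root).
with open("data/train.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

# Each sample carries the original dialogue plus its CoT annotation.
first = samples[0]
print(first["original_data"]["strategy"])
print(first["cot_data"]["strategy_reason"])
```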
39 |
40 | ## Model Training
41 |
42 | ### Download the pretrained models
43 | Download the [**LLAMA2-7B-CHAT**](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) model.
44 |
45 | The training of the LLAMA2-CHAT model is based on the [**SFT Trainer from Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl).
46 |
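If you have been granted access to the gated checkpoint on the Hugging Face Hub, one possible way to fetch it is via `huggingface_hub` (installed as a dependency of `transformers`). This is only a sketch; the local directory below simply mirrors the path used in `scripts/supervised_finetune_llama2_cot_ablation.sh`:

```python
from huggingface_hub import snapshot_download

# Download the gated LLaMA2-7B-Chat weights (requires `huggingface-cli login` first).
snapshot_download(
    repo_id="meta-llama/Llama-2-7b-chat-hf",
    local_dir="./pretrained_models/Llama-2-7b-chat-hf",
)
```
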
47 | ### Train Model
48 | Run `bash scripts/supervised_finetune_llama2_cot.sh` to train your model.
49 |
50 | Run `bash scripts/supervised_finetune_llama2_cot_ablation.sh` to train the ablation-study models.
51 |
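Training saves the LoRA adapter to `<output_dir>/final_checkpoint/`. Unless `--merge_lora` is passed to `supervised_finetuning_cot.py`, the adapter can be merged into the base model afterwards with the helper in `merge.py`. A sketch (the paths are illustrative and follow `scripts/supervised_finetune_llama2_cot.sh`):

```python
from merge import merge_llm_with_lora

# Merge the trained LoRA adapter into the base LLaMA2-7B-Chat weights;
# the merged model is written to <output_name>/final_checkpoint-merged.
merge_llm_with_lora(
    base_model_name="./pretrained_models/Llama-2-7b-chat-hf",
    adapter_model_name="./checkpoints/cot/supervised_llama2_cot/final_checkpoint/",
    output_name="./checkpoints/cot/supervised_llama2_cot/",
)
```
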
52 | ### Test Model
53 | Run `bash scripts/test_llama2_chat_sft_cot.sh` for interactive chatting or `bash scripts/test_llama2_inference_cot.sh` for batch inference on a JSON file; set `MODEL_PATH` (and `JSON_PATH` for batch inference) in the scripts to your checkpoint first.
54 |
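`test_llama2_inference_cot.py` writes its predictions to `val_inference_results.json` or `test_inference_results.json` inside the directory given by `MODEL_PATH`. A minimal way to inspect them (the checkpoint path below is only an example):

```python
import json

# Each entry holds the instruction input, the reference output ("label"),
# the model prediction, and the dialogue id.
results_path = "./checkpoints/cot/supervised_llama2_cot/final_checkpoint-merged/val_inference_results.json"
with open(results_path, "r", encoding="utf-8") as f:
    results = json.load(f)

print(results[0]["label"])
print(results[0]["prediction"])
```
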
55 | ## Cite
56 | If you use our code or your research is related to our work, please cite our paper:
57 | ```bib
58 | @inproceedings{zhang-etal-2024-escot,
59 | title = "{ESC}o{T}: Towards Interpretable Emotional Support Dialogue Systems",
60 | author = "Zhang, Tenggan and Zhang, Xinjie and Zhao, Jinming and Zhou, Li and Jin, Qin",
61 | booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
62 | year = "2024"
63 | }
64 | ```
65 |
66 | Please contact zhangxinjie827@ruc.edu.cn if you have any problems.
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from peft import PeftModel
5 | from transformers import (
6 | AutoModelForCausalLM,
7 | AutoTokenizer,
8 | LlamaTokenizer
9 | )
10 |
11 | DEFAULT_PAD_TOKEN = "[PAD]"
12 | DEFAULT_EOS_TOKEN = "</s>"
13 | DEFAULT_BOS_TOKEN = "<s>"
14 | DEFAULT_UNK_TOKEN = "<unk>"
15 |
16 |
17 | def merge_llm_with_lora(base_model_name, adapter_model_name, output_name, push_to_hub=False):
18 | base_model = AutoModelForCausalLM.from_pretrained(
19 | base_model_name,
20 | return_dict=True,
21 | torch_dtype=torch.bfloat16
22 | )
23 |
24 | model = PeftModel.from_pretrained(base_model, adapter_model_name)
25 | model = model.merge_and_unload()
26 |
27 | if "decapoda" in base_model_name.lower():
28 | tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
29 | tokenizer.add_special_tokens(
30 | {
31 | "eos_token": DEFAULT_EOS_TOKEN,
32 | "bos_token": DEFAULT_BOS_TOKEN,
33 | "unk_token": DEFAULT_UNK_TOKEN,
34 | "pad_token": DEFAULT_PAD_TOKEN,
35 | }
36 | )
37 | else:
38 | tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=False)
39 |
40 | if push_to_hub:
41 | print(f"Saving to hub ...")
42 | model.push_to_hub(f"{base_model_name}-merged", use_temp_dir=False, private=True)
43 | tokenizer.push_to_hub(f"{base_model_name}-merged", use_temp_dir=False, private=True)
44 | else:
45 | output_name = os.path.join(output_name, "final_checkpoint-merged")
46 | model.save_pretrained(output_name)
47 | tokenizer.save_pretrained(output_name)
48 | print(f"Model saved to {output_name}")
49 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | datasets>=1.17.0
3 | transformers>=4.28.0
4 | accelerate
5 | evaluate
6 | # git+https://github.com/huggingface/peft.git
7 | # git+https://github.com/lvwerra/trl.git
8 | peft
9 | trl
10 | tqdm
11 | sentencepiece
12 | bitsandbytes
13 | wandb
--------------------------------------------------------------------------------
/scripts/supervised_finetune_llama2_cot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3
4 |
5 | run_name='supervised_llama2_cot'
6 |
7 | torchrun --nnodes 1 --nproc_per_node 4 supervised_finetuning_cot.py \
8 | --base_model './pretrained_models/Llama-2-7b-chat-hf' \
9 | --dataset_name './data/ablation_data/em_es_ia_sr_re' \
10 | --lr_scheduler_type 'cosine' \
11 | --learning_rate 1e-5 \
12 | --max_steps 10000 \
13 | --save_freq 500 \
14 | --seq_length 2048 \
15 | --batch_size 8 \
16 | --run_name $run_name \
17 | --output_dir './checkpoints/cot/'$run_name
18 |
--------------------------------------------------------------------------------
/scripts/supervised_finetune_llama2_cot_ablation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0,1,2,3
4 |
5 | base_model='./pretrained_models/Llama-2-7b-chat-hf'
6 | dataset_base='./data/ablation_data/'
7 | output_base='./checkpoints/cot/'
8 | lr_scheduler_type='cosine'
9 | learning_rate=5e-5
10 | max_steps=380
11 | save_freq=38
12 | eval_freq=380
13 | seq_length=2048
14 | batch_size=8
15 |
16 | settings=("em_es_ia_sr_re" "em_sr_re" "sr_re" "re")
17 |
18 | for setting in "${settings[@]}"
19 | do
20 | run_name="supervised_llama2_cot_ablation_${setting}"
21 |
22 | dataset_name="${dataset_base}${setting}"
23 |
24 | output_dir="${output_base}${run_name}"
25 |
26 | echo "Running setting: $setting"
27 | echo "Dataset path: $dataset_name"
28 | echo "Output path: $output_dir"
29 |
30 | torchrun --nnodes 1 --nproc_per_node 4 supervised_finetuning_cot.py \
31 | --base_model "$base_model" \
32 | --dataset_name "$dataset_name" \
33 | --lr_scheduler_type "$lr_scheduler_type" \
34 | --learning_rate "$learning_rate" \
35 | --max_steps "$max_steps" \
36 | --save_freq "$save_freq" \
37 | --eval_freq "$eval_freq" \
38 | --seq_length "$seq_length" \
39 | --batch_size "$batch_size" \
40 | --run_name "$run_name" \
41 | --output_dir "$output_dir" \
42 | --save_total_limit 100 \
43 | --seed 1104
44 | done
45 |
--------------------------------------------------------------------------------
/scripts/test_llama2_chat_sft_cot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the model path and GPU id as variables
4 | MODEL_PATH="change_to_checkpoint_path"
5 | GPU_ID=0
6 |
7 | # Execute the Python script with the model path and GPU id as arguments
8 | python test_llama2_chat_cot.py --model_path $MODEL_PATH --gpu_id $GPU_ID
9 |
--------------------------------------------------------------------------------
/scripts/test_llama2_inference_cot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the model path and GPU id as variables
4 | MODEL_PATH="change_to_checkpoint_path"
5 | JSON_PATH="./data/ablation_data/em_es_ia_sr_re/empathetic_dialogue_valid.json"
6 | GPU_ID=0
7 |
8 | # Execute the Python script with the model path and GPU id as arguments
9 | python test_llama2_inference_cot.py --model_path $MODEL_PATH --gpu_id $GPU_ID --json_path $JSON_PATH
10 |
--------------------------------------------------------------------------------
/supervised_finetuning_cot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | from tqdm import tqdm
5 | from accelerate import Accelerator
6 | from datasets import load_dataset
7 | from peft import LoraConfig
8 | from transformers import (
9 | AutoModelForCausalLM,
10 | AutoTokenizer,
11 | LlamaTokenizer,
12 | TrainingArguments,
13 | logging,
14 | set_seed
15 | )
16 | from trl import SFTTrainer
17 | from trl.trainer import ConstantLengthDataset
18 |
19 | from merge import merge_llm_with_lora
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--base_model", type=str, default="")
25 | parser.add_argument("--dataset_name", type=str, default="./data/alpaca_gpt4_data.json")
26 | parser.add_argument("--split", type=str, default="train")
27 | parser.add_argument("--size_valid_set", type=int, default=4000)
28 | parser.add_argument("--streaming", action="store_true", default=False)
29 | parser.add_argument("--shuffle_buffer", type=int, default=5000)
30 |
31 | parser.add_argument("--seq_length", type=int, default=1024)
32 | parser.add_argument("--max_steps", type=int, default=10000)
33 | parser.add_argument("--batch_size", type=int, default=16)
34 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
35 | parser.add_argument("--eos_token_id", type=int, default=49152)
36 |
37 | parser.add_argument("--lora_r", type=int, default=16)
38 | parser.add_argument("--lora_alpha", type=int, default=32)
39 | parser.add_argument("--lora_dropout", type=float, default=0.05)
40 | parser.add_argument("--lora_target_modules", type=str, default=None)
41 |
42 | parser.add_argument("--learning_rate", type=float, default=1e-4)
43 | parser.add_argument("--lr_scheduler_type", type=str, default="linear")
44 | parser.add_argument("--num_warmup_steps", type=int, default=100)
45 | parser.add_argument("--weight_decay", type=float, default=0.05)
46 | parser.add_argument("--warmup_ratio", type=float, default=0.)
47 |
48 | parser.add_argument("--local_rank", type=int, default=0)
49 | parser.add_argument("--fp16", action="store_true", default=False)
50 | parser.add_argument("--no_bf16", action="store_false", default=True)
51 | parser.add_argument("--no_gradient_checkpointing", action="store_false", default=True)
52 | parser.add_argument("--seed", type=int, default=1103)
53 | parser.add_argument("--num_workers", type=int, default=None)
54 | parser.add_argument("--output_dir", type=str, default="./checkpoints/supervised_llama/")
55 | parser.add_argument("--log_freq", type=int, default=1)
56 | parser.add_argument("--eval_freq", type=int, default=1000)
57 | parser.add_argument("--save_freq", type=int, default=1000)
58 | parser.add_argument("--save_total_limit", type=int, default=3)
59 | parser.add_argument("--resume_from_checkpoint", type=str, default=None)
60 | parser.add_argument("--run_name", type=str, default="llama-supervised-finetuned")
61 | parser.add_argument("--merge_lora", action="store_true", default=False)
62 |
63 | return parser.parse_args()
64 |
65 |
66 | def chars_token_ratio(dataset, tokenizer, nb_examples=400):
67 | """
68 | Estimate the average number of characters per token in the dataset.
69 | """
70 | total_characters, total_tokens = 0, 0
71 | max_token_length = 0
72 | for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
73 | text = prepare_sample_text(example)
74 | total_characters += len(text)
75 | if tokenizer.is_fast:
76 | total_tokens += len(tokenizer(text).tokens())
77 | if len(tokenizer(text).tokens()) > max_token_length:
78 | max_token_length = len(tokenizer(text).tokens())
79 | else:
80 | total_tokens += len(tokenizer.tokenize(text))
81 | if len(tokenizer.tokenize(text)) > max_token_length:
82 | max_token_length = len(tokenizer.tokenize(text))
83 |
84 | print(f"max token length: {max_token_length}")
85 | return total_characters / total_tokens
86 |
87 |
88 | def print_trainable_parameters(model):
89 | """
90 | Prints the number of trainable parameters in the model.
91 | """
92 | trainable_params = 0
93 | all_param = 0
94 | for _, param in model.named_parameters():
95 | all_param += param.numel()
96 | if param.requires_grad:
97 | trainable_params += param.numel()
98 | print(
99 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
100 | )
101 |
102 | # adapted to llama2
103 | def prepare_sample_text(data_point):
104 | """Prepare the text from a sample of the dataset."""
105 | if data_point["input"]:
106 | return f"""Human:
107 | {data_point["input"]}
108 | {data_point["instruction"]}
109 | Assistant:
110 | {data_point["output"]}
111 | """
112 | else:
113 | return f"""Human:
114 | {data_point["instruction"]}
115 | Assistant:
116 | {data_point["output"]}
117 | """
118 |
119 |
120 | def create_datasets(tokenizer, args):
121 | train_json_path = os.path.join(args.dataset_name, "train/empathetic_dialogue_train.json")
122 | train_data = load_dataset("json", data_files=train_json_path, split="train")
123 | train_data = train_data.shuffle(seed=args.seed)
124 |
125 | val_json_path = os.path.join(args.dataset_name, "empathetic_dialogue_valid.json")
126 | valid_data = load_dataset("json", data_files=val_json_path, split="train")
127 | valid_data = valid_data.shuffle(seed=args.seed)
128 |
129 | chars_per_token = chars_token_ratio(train_data, tokenizer)
130 | print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
131 |
132 | train_dataset = ConstantLengthDataset(
133 | tokenizer,
134 | train_data,
135 | formatting_func=prepare_sample_text,
136 | infinite=True,
137 | seq_length=args.seq_length,
138 | chars_per_token=chars_per_token,
139 | )
140 | valid_dataset = ConstantLengthDataset(
141 | tokenizer,
142 | valid_data,
143 | formatting_func=prepare_sample_text,
144 | infinite=False,
145 | seq_length=args.seq_length,
146 | chars_per_token=chars_per_token,
147 | )
148 |
149 | print(f"Size of the train dataset: {len(train_dataset)}")
150 | print(f"Size of the validation dataset: {len(valid_dataset)}")
151 |
152 | return train_dataset, valid_dataset
153 |
154 |
155 | def run_training(args, train_data, val_data, tokenizer=None):
156 | print("Loading the model")
157 |
158 | lora_config = LoraConfig(
159 | r=args.lora_r,
160 | lora_alpha=args.lora_alpha,
161 | lora_dropout=args.lora_dropout,
162 | target_modules=args.lora_target_modules,
163 | bias="none",
164 | task_type="CAUSAL_LM",
165 | )
166 |
167 | train_data.start_iteration = 0
168 |
169 | print("Starting main loop")
170 |
171 | training_args = TrainingArguments(
172 | output_dir=args.output_dir,
173 | dataloader_drop_last=True,
174 | evaluation_strategy="steps",
175 | max_steps=args.max_steps,
176 | eval_steps=args.eval_freq,
177 | save_steps=args.save_freq,
178 | logging_steps=args.log_freq,
179 | save_total_limit=args.save_total_limit,
180 | per_device_train_batch_size=args.batch_size,
181 | per_device_eval_batch_size=args.batch_size,
182 | learning_rate=args.learning_rate,
183 | lr_scheduler_type=args.lr_scheduler_type,
184 | warmup_steps=args.num_warmup_steps,
185 | gradient_accumulation_steps=args.gradient_accumulation_steps,
186 | gradient_checkpointing=args.no_gradient_checkpointing,
187 | fp16=args.fp16,
188 | bf16=args.no_bf16,
189 | weight_decay=args.weight_decay,
190 | warmup_ratio=args.warmup_ratio,
191 | run_name=args.run_name,
192 | report_to="wandb",
193 | ddp_find_unused_parameters=False if int(os.environ.get("WORLD_SIZE", 1)) != 1 else None,
194 | )
195 |
196 | model = AutoModelForCausalLM.from_pretrained(
197 | args.base_model,
198 | load_in_8bit=True,
199 | )
200 |
201 | if args.resume_from_checkpoint:
202 | # Check the available weights and load them
203 | checkpoint_name = os.path.join(
204 | args.resume_from_checkpoint, "pytorch_model.bin"
205 | ) # Full checkpoint
206 | if not os.path.exists(checkpoint_name):
207 | checkpoint_name = os.path.join(
208 | args.resume_from_checkpoint, "adapter_model.bin"
209 | ) # only LoRA model - LoRA config above has to fit
210 | args.resume_from_checkpoint = None
211 |
212 | if os.path.exists(checkpoint_name):
213 | import torch
214 | from peft import (
215 | get_peft_model,
216 | prepare_model_for_int8_training,
217 | set_peft_model_state_dict
218 | )
219 | print(f"Restarting from {checkpoint_name}")
220 | model = prepare_model_for_int8_training(model)
221 | model = get_peft_model(model, lora_config)
222 |
223 | adapters_weights = torch.load(checkpoint_name)
224 | set_peft_model_state_dict(model, adapters_weights)
225 | else:
226 | print(f"Checkpoint {checkpoint_name} not found")
227 |
228 | trainer = SFTTrainer(
229 | model=model,
230 | tokenizer=tokenizer,
231 | args=training_args,
232 | train_dataset=train_data,
233 | eval_dataset=val_data,
234 | peft_config=lora_config,
235 | max_seq_length=args.seq_length,
236 | packing=True,
237 | )
238 |
239 | print_trainable_parameters(model)
240 |
241 | print("Training...")
242 | trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
243 |
244 | print("Saving last checkpoint of the model")
245 | final_model_path = os.path.join(args.output_dir, "final_checkpoint/")
246 | trainer.model.save_pretrained(final_model_path)
247 |
248 | if args.merge_lora:
249 | merge_llm_with_lora(args.base_model, final_model_path, args.output_dir)
250 |
251 |
252 | def main(args):
253 | if "llama" in args.base_model.lower():
254 | tokenizer = LlamaTokenizer.from_pretrained(args.base_model)
255 | tokenizer.add_special_tokens(
256 | {
257 | "eos_token": "",
258 | "bos_token": "",
259 | "unk_token": "",
260 | "pad_token": "",
261 | }
262 | )
263 | else:
264 | tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=False)
265 | if getattr(tokenizer, "pad_token", None) is None:
266 | tokenizer.pad_token = tokenizer.eos_token
267 |
268 | train_dataset, eval_dataset = create_datasets(tokenizer, args)
269 | run_training(args, train_dataset, eval_dataset, tokenizer)
270 |
271 |
272 | if __name__ == "__main__":
273 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
274 |
275 | args = get_args()
276 | assert args.base_model != "", "Please provide the llama model path"
277 |
278 | set_seed(args.seed)
279 | os.makedirs(args.output_dir, exist_ok=True)
280 |
281 | logging.set_verbosity_error()
282 |
283 | main(args)
284 |
--------------------------------------------------------------------------------
/test_llama2_chat_cot.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
2 | import sys
3 | import argparse
4 | import torch
5 |
6 | def main(args):
7 | prompt = "Generate the response using the pipeline of emotion, emotion stimulus, individual appraisal, strategy reason and response."
8 |
9 | device = torch.device(f"cuda:{args.gpu_id}") if torch.cuda.is_available() else torch.device("cpu")
10 | model = LlamaForCausalLM.from_pretrained(args.model_path, device_map=device, low_cpu_mem_usage=True)
11 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
12 |
13 | print("Human:")
14 | line = input()
15 | inputs = ""
16 | while line:
17 | if inputs == "":
18 | inputs = 'Human: ' + line.strip() + '\n' + prompt + '\nAssistant:'
19 | else:
20 | inputs = inputs.replace("Human: ", "")
21 | inputs = inputs.replace('\n'+prompt, "")
22 | inputs = inputs + '\nseeker: ' + line.strip()
23 | inputs = 'Human: ' + inputs + '\n' + prompt + '\nAssistant:'
24 |
25 | input_ids = tokenizer(inputs, return_tensors="pt").input_ids
26 | input_ids = input_ids.to(device)
27 | outputs = model.generate(input_ids, max_new_tokens=500, do_sample = True, top_k = 30, top_p = 0.85, temperature = 0.5, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
28 | rets = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
29 | response = rets[0].strip().replace(inputs, "")
30 |
31 | inputs = inputs.replace("\nAssistant:", '\nsupporter:'+ response.strip())
32 | print("\nAssistant:" + response)
33 | print("\n------------------------------------------------\nSeeker:")
34 | line = input()
35 |
36 | if line == "clear":
37 | inputs = ""
38 | print("History cleared.")
39 | print("\n------------------------------------------------\nHuman:")
40 | line = input()
41 | if line == "exit":
42 | sys.exit()
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--model_path", type=str, required=True, help="The path of the pretrained model.")
47 | parser.add_argument("--gpu_id", type=int, required=True, help="The id of the GPU to be used.")
48 | args = parser.parse_args()
49 |
50 | main(args)
51 |
--------------------------------------------------------------------------------
/test_llama2_inference_cot.py:
--------------------------------------------------------------------------------
1 | import json
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
3 | import torch
4 | from tqdm import tqdm
5 | import os
6 |
7 | # Function to prepare data
8 | def prepare_sample_text(data_point):
9 | """Prepare the text from a sample of the dataset."""
10 | return f"""Human:
11 | {data_point["input"]}
12 | {data_point["instruction"]}
13 | Assistant: """
14 |
15 | # Main function
16 | def main(args):
17 | # Load model and tokenizer
18 | device = torch.device(f"cuda:{args.gpu_id}") if torch.cuda.is_available() else torch.device("cpu")
19 | model = LlamaForCausalLM.from_pretrained(args.model_path, device_map=device, low_cpu_mem_usage=True)
20 | tokenizer = AutoTokenizer.from_pretrained(args.model_path)
21 |
22 | # Read JSON file
23 | with open(args.json_path, 'r') as file:
24 | data = json.load(file)
25 |
26 | # Store inference results
27 | results = []
28 |
29 | # Perform inference for each data point
30 | for data_point in tqdm(data):
31 | prepared_text = prepare_sample_text(data_point)
32 | input_ids = tokenizer(prepared_text, return_tensors="pt").input_ids
33 | prepared_text = tokenizer.decode(input_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
34 | input_ids = input_ids.to(device)
35 |
36 | # Execute model inference
37 | outputs = model.generate(input_ids, max_new_tokens=500, do_sample=True, top_k=30, top_p=0.85, temperature=0.5, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
38 | rets = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
39 |
40 | # Extract response and remove original text
41 | full_response = rets[0].strip()
42 | response = full_response.replace(prepared_text, "").strip()
43 |
44 | # Save results
45 | results.append({'input': data_point['input'], 'label': data_point['output'], 'prediction': response, 'dialog_id': data_point['dialog_id']})
46 |
47 | # Save results to file
48 | if "test" in args.json_path:
49 | with open(os.path.join(args.model_path, 'test_inference_results.json'), 'w') as outfile:
50 | json.dump(results, outfile, indent=4, ensure_ascii=False)
51 | else:
52 | with open(os.path.join(args.model_path, 'val_inference_results.json'), 'w') as outfile:
53 | json.dump(results, outfile, indent=4, ensure_ascii=False)
54 |
55 | # Use argparse to handle command line arguments
56 | if __name__ == "__main__":
57 | import argparse
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("--model_path", type=str, required=True, help="The path of the pretrained model.")
60 | parser.add_argument("--gpu_id", type=int, required=True, help="The id of the GPU to be used.")
61 | parser.add_argument("--json_path", type=str, required=True, help="Path to the JSON file containing the data.")
62 |
63 | args = parser.parse_args()
64 |
65 | main(args)
--------------------------------------------------------------------------------