├── .gitignore ├── data └── README.md ├── .gitmodules ├── requirements.txt ├── .vscode ├── settings.json └── launch.json ├── scripts ├── infer.sh ├── train_llm.sh └── infer_bioserc_bertbased.ipynb ├── LICENSE ├── src ├── reformat_data_ft_llm.py ├── ft_llm.py └── llm_bio_extract.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .history/ 3 | finetuned_llm/ 4 | **.log 5 | **/__pycache__/ 6 | **events.out.tfevents.* -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | You can find it here (our previous work): https://github.com/yingjie7/per_erc/blob/main/data.zip 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/baseline"] 2 | path = src/baseline 3 | url = https://github.com/yingjie7/per_erc.git 4 | branch = AccumWR 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning 2 | torch 3 | huggingface 4 | lightning 5 | accelerate 6 | bitsandbytes 7 | datasets 8 | peft 9 | trl 10 | scikit-learn 11 | flash-attn 12 | tensorboard 13 | tensorboardX -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.exclude": { 3 | "**/.git": true, 4 | "**/.svn": true, 5 | "**/.hg": true, 6 | "**/CVS": true, 7 | "**/.DS_Store": true, 8 | "**/Thumbs.db": true, 9 | "**/env*": true, 10 | } 11 | } -------------------------------------------------------------------------------- /scripts/infer.sh: -------------------------------------------------------------------------------- 1 | conda activate env_py38/ && \ 2 | python src/ft_llm.py --do_eval_dev --do_eval_test \ 3 | --base_model_id "meta-llama/Llama-2-7b-hf" --ft_model_id \ 4 | "debug-check" --lr_scheduler "linear" --lr "3e-4" \ 5 | --lora_r 32 --max_steps -1 --epoch 3 --kshot 0 --window 5 \ 6 | --data_name "meld" --prompting_type "spdescV2" --extract_prompting_llm_id "Llama-2-70b-chat-hf" \ 7 | --re_gen_data --seed 45 --max_seq_len 2048 --eval_delay 600 \ 8 | --data_folder "./data" \ 9 | --ft_model_path "./finetuned_llm/meld_Llama-2-13b-hf_ep3_lrs-linear3e-4_0shot_r32_w5_spdescV2_devop_noise_s2048_R46_rc/" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 xuejieying 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "prac", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/src/reformat_data_ft_llm.py", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | }, 15 | { 16 | "name": "ft_llm", 17 | "type": "python", 18 | "request": "launch", 19 | "program": "${workspaceFolder}/src/ft_llm.py", 20 | "cwd": "${workspaceFolder}/", 21 | "args": [ 22 | "--do_eval_dev", 23 | "--do_eval_test", 24 | "--do_train", 25 | "--base_model_id", "meta-llama/Llama-2-7b-hf", 26 | "--ft_model_id", 27 | "debug", 28 | "--lr_scheduler", "linear", 29 | "--lr", "3e-4", 30 | "--epoch","3", 31 | "--lora_r", "32", 32 | "--kshot", "0" , 33 | "--window", "5" , 34 | "--data_name", "iemocap" , 35 | "--prompting_type","fewshot-similar-default", 36 | "--extract_prompting_llm_id","Llama-2-70b-chat-hf", 37 | "--re_gen_data" 38 | ], 39 | "console": "integratedTerminal", 40 | "justMyCode": true 41 | } 42 | ] 43 | } -------------------------------------------------------------------------------- /scripts/train_llm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # 3 | # Job Script for VPCC , JAIST 4 | # 2018.2.25 5 | 6 | #PBS -N erc-llm 7 | #PBS -j oe 8 | #PBS -q GPU-1 9 | #PBS -o pbs_infer-sp.log 10 | #PBS -e infer-sp.err.log 11 | 12 | # source ~/.bashrc 13 | 14 | # conda activate env_llm/ 15 | 16 | EP=3 17 | LR_SCHEDULER="linear" 18 | LR=3e-4 19 | LORA_R=32 20 | TOPK=0 21 | WINDOW=5 22 | PROMPT_TYPE="spdescV2" # spdescV2 | default 23 | MODEL_ID="meta-llama/Llama-2-7b-hf" # "meta-llama/Llama-2-7b-hf" 24 | DATANAME="meld" # iemocap | meld | emorynlp 25 | EXTRACT_PROMTING_LLM_ID="Llama-2-70b-chat-hf" # Meta-Llama-3-8B-Instruct 26 | MAX_SEQ_LEN=2048 # 1024 for IEMOCAP, EMORYNLP; 2048 for MELD 27 | MAX_STEPS=-1 28 | EVAL_DELAY=600 29 | 30 | 31 | IFS='/' read -ra ADDR <<< "$MODEL_ID" 32 | MODEL_ID_0=${ADDR[1]} 33 | 34 | for seed in 42 43 44 45 46 ; 35 | do 36 | python ./src/ft_llm.py --do_eval_dev --do_eval_test --do_train \ 37 | --base_model_id $MODEL_ID \ 38 | --ft_model_id ${DATANAME}_${MODEL_ID_0}_ep${EP}_step${MAX_STEPS}_lrs-${LR_SCHEDULER}${LR}_${TOPK}shot_r${LORA_R}_w${WINDOW}_${PROMPT_TYPE}_seed${seed}_L${MAX_SEQ_LEN}_llmdesc${EXTRACT_PROMTING_LLM_ID}_ED${EVAL_DELAY} \ 39 | --lr_scheduler $LR_SCHEDULER --lr $LR --lora_r $LORA_R --max_steps $MAX_STEPS --epoch ${EP} \ 40 | --kshot $TOPK --window $WINDOW --data_name $DATANAME --prompting_type ${PROMPT_TYPE} --extract_prompting_llm_id $EXTRACT_PROMTING_LLM_ID \ 41 | --re_gen_data --seed $seed --max_seq_len $MAX_SEQ_LEN --eval_delay $EVAL_DELAY --data_folder ./data/ 42 | 43 | done 44 | 45 | wait 46 | 47 | 
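When `--re_gen_data` is passed (as in the loop above), `src/ft_llm.py` regenerates the prompt data by calling `process()` from `src/reformat_data_ft_llm.py` (next file), writing chat-format JSONL files. The sketch below shows roughly what one written record looks like for the `spdescV2` prompting type; it is illustrative only, and the speaker biography, utterances, and label are invented placeholders.

```python
# Illustrative only: approximate shape of one record in
# data/meld.train.0shot_w5_spdescV2.jsonl (all values below are placeholders, not real data).
example_record = {
    "messages": [
        {
            "role": "system",
            "content": (
                "### You are an expert at analyzing the emotion of utterances among speakers in a conversation."
                "\n### Given the characteristic of this speaker, SPEAKER_0: \n<biography extracted by the LLM>"
                "\n### Given the following conversation as a context \n SPEAKER_0: <utterance> ..."
            ),
        },
        {
            "role": "user",
            "content": 'Based on above conversation and characteristic of the speakers, '
                       'which emotional label of SPEAKER_0 in the utterance "<utterance>".',
        },
        {"role": "assistant", "content": "joy"},
    ]
}
```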
-------------------------------------------------------------------------------- /src/reformat_data_ft_llm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | data_name_pattern = 'train' 5 | 6 | def get_speaker_name(s_id, gender, data_name): 7 | if data_name == "iemocap": 8 | # iemocap: label index mapping = {'hap':0, 'sad':1, 'neu':2, 'ang':3, 'exc':4, 'fru':5} 9 | speaker = { 10 | "Ses01": {"F": "Mary", "M": "James"}, 11 | "Ses02": {"F": "Patricia", "M": "John"}, 12 | "Ses03": {"F": "Jennifer", "M": "Robert"}, 13 | "Ses04": {"F": "Linda", "M": "Michael"}, 14 | "Ses05": {"F": "Elizabeth", "M": "William"}, 15 | } 16 | s_id_first_part = s_id[:5] 17 | return speaker[s_id_first_part][gender].upper() 18 | elif data_name in ['meld', "emorynlp"]: 19 | # emorynlp: label index mapping = {'Joyful': 0, 'Mad': 1, 'Peaceful': 2, 'Neutral': 3, 'Sad': 4, 'Powerful': 5, 'Scared': 6} 20 | # meld: label index mapping = {'neutral': 0, 'surprise': 1, 'fear': 2, 'sadness': 3, 'joy': 4, 'disgust': 5, 'anger':6} 21 | gender_idx = gender.index(1) 22 | return f"SPEAKER_{gender_idx}" 23 | elif data_name=='dailydialog': 24 | # dailydialog: {'no_emotion': 0, 'happiness': 1, 'sadness': 2, 'surprise': 3, 'anger': 4, 'fear': 5, 'disgust':6} 25 | return f"SPEAKER_{gender}" 26 | 27 | 28 | def flatten_conversation_mixed_by_surrounding(conv, around_window, s_id, genders, data_name): 29 | new_data = [] 30 | for i, cur_sent in enumerate(conv): 31 | tmp_window = [] 32 | for j in range(max(0, i-around_window), min(len(conv), i+around_window+1)): 33 | tmp_window.append(f" {get_speaker_name(s_id, genders[j], data_name=data_name)}: {conv[j]}") 34 | 35 | new_data.append(tmp_window) 36 | return new_data 37 | 38 | def get_label_map(data_name): 39 | all_data_label_map = { 40 | "iemocap": {0:'happy',1:'sad',2:'neutral',3:'angry',4:'excited',5:'frustrated'}, 41 | "emorynlp": ['Joyful', 'Mad', 'Peaceful', 'Neutral', 'Sad', 'Powerful', 'Scared'], 42 | "meld": ['neutral', 'surprise', 'fear', 'sadness', 'joy', 'disgust', 'anger'], 43 | "dailydialog": ['no_emotion', 'happiness', 'sadness', 'surprise', 'anger', 'fear', 'disgust'] 44 | } 45 | return all_data_label_map[data_name] 46 | 47 | def preprocess_desc_speaker(str_in): 48 | str_in = str_in.split("")[0].replace("", "").replace("\n", " ") 49 | str_out = re.sub(r" {2,}", " ", str_in) 50 | return str_out 51 | 52 | def gen_default_prompting_messages(data_name, conv, around_window, s_id, desc_speaker_data=None): 53 | new_conv = [] 54 | samples = [] 55 | for i,sent in enumerate(conv['sentences']): 56 | new_sent_gender = conv['genders'][i] 57 | sent_name = get_speaker_name(s_id,new_sent_gender, data_name) 58 | new_sent = f'{sent_name}: {sent}' 59 | new_conv.append(new_sent) 60 | conv_str = "\n".join(new_conv) 61 | 62 | flatten_conv = flatten_conversation_mixed_by_surrounding(conv['sentences'], around_window, s_id, conv['genders'], data_name) 63 | 64 | for i, sent in enumerate(new_conv): 65 | system_msg = f'### You are an expert at analyzing the emotion of utterances among speakers in a conversation.' 66 | conv_str = "\n".join(flatten_conv[i]) 67 | local_context_msg = f"\n### Given the following conversation as a context \n{conv_str}" 68 | speaker_name = get_speaker_name(s_id, conv["genders"][i], data_name) 69 | q_msg = f'Based on above conversation, which emotional label of {speaker_name} in the utterance \"{conv["sentences"][i]}\".' 
70 | 71 | label_msg = get_label_map(data_name)[conv['labels'][i]] 72 | 73 | samples.append({ 74 | "messages": [ 75 | {'role': "system", 'content': system_msg + local_context_msg}, 76 | {'role': "user", 'content': q_msg}, 77 | {'role': "assistant", 'content': label_msg}, 78 | ] 79 | }) 80 | return samples 81 | 82 | def gen_spdescV2_prompting_messages(data_name, conv, around_window, s_id, desc_speaker_data): 83 | new_conv = [] 84 | for i,sent in enumerate(conv['sentences']): 85 | new_sent_gender = conv['genders'][i] 86 | sent_name = get_speaker_name(s_id,new_sent_gender, data_name) 87 | new_sent = f'{sent_name}: {sent}' 88 | new_conv.append(new_sent) 89 | conv_str = "\n".join(new_conv) 90 | 91 | flatten_conv = flatten_conversation_mixed_by_surrounding(conv['sentences'], around_window, s_id, conv['genders'], data_name) 92 | 93 | samples = [] 94 | for i, sent in enumerate(new_conv): 95 | system_msg = f'### You are an expert at analyzing the emotion of utterances among speakers in a conversation.' 96 | speaker_name = get_speaker_name(s_id, conv["genders"][i], data_name) 97 | 98 | desc_str = desc_speaker_data[s_id][i].replace("\n", " ") 99 | 100 | 101 | 102 | desc_msg = f'\n### Given the characteristic of this speaker, {speaker_name}: \n{desc_str}' 103 | 104 | conv_str = "\n".join(flatten_conv[i]) 105 | local_context_msg = f"\n### Given the following conversation as a context \n{conv_str}" 106 | 107 | q_msg = f'Based on above conversation and characteristic of the speakers, which emotional label of {speaker_name} in the utterance \"{conv["sentences"][i]}\".' 108 | label_msg = get_label_map(data_name)[conv['labels'][i]] 109 | 110 | samples.append({ 111 | "messages": [ 112 | {'role': "system", 'content': system_msg + desc_msg + local_context_msg}, 113 | {'role': "user", 'content': q_msg}, 114 | {'role': "assistant", 'content': label_msg}, 115 | ] 116 | }) 117 | 118 | return samples 119 | 120 | def process(paths_folder_preprocessed_data, args): 121 | 122 | process_kwargs = {} 123 | for path_folder_preprocessed_data in paths_folder_preprocessed_data: 124 | 125 | d_type = 'train' if '.train.' in path_folder_preprocessed_data else \ 126 | 'valid' if '.valid.' in path_folder_preprocessed_data else \ 127 | 'test' if '.test.' 
in path_folder_preprocessed_data else None 128 | 129 | folder_data = args.data_folder 130 | around_window = args.window 131 | data_name = args.data_name 132 | path_data_out = path_folder_preprocessed_data 133 | prompting_type = args.prompting_type 134 | extract_prompting_llm_id = args.extract_prompting_llm_id 135 | 136 | raw_data = f'{folder_data}/{data_name}.{d_type}.json' 137 | org_data = json.load(open(raw_data)) # ; org_data = dict([(k,v) for k,v in org_data.items()][:10]) 138 | 139 | new_format = [] 140 | 141 | # if use speaker description -> load raw data and preprocess 142 | if prompting_type not in ["default" ]: 143 | desc_speaker_data = json.load(open(f'{folder_data}/{data_name}.{d_type}_{prompting_type}_{extract_prompting_llm_id}.json')) 144 | processed_desc_speaker_data = {} 145 | if desc_speaker_data is not None and "spdesc" in prompting_type: 146 | for s_id, desc_all_conv in desc_speaker_data.items(): 147 | processed_desc_speaker_data[s_id] = [preprocess_desc_speaker(spdesc) for spdesc in desc_all_conv] 148 | desc_speaker_data = processed_desc_speaker_data 149 | else: 150 | desc_speaker_data = None 151 | 152 | # path data out 153 | path_processed_data = raw_data.replace(".json", f".0shot_w{around_window}_{prompting_type}.jsonl") if path_data_out is None else path_data_out 154 | 155 | # prompting process function 156 | process_function_map = { 157 | "spdescV2": gen_spdescV2_prompting_messages, 158 | "default": gen_default_prompting_messages, 159 | } 160 | 161 | process_func = process_function_map.get(prompting_type, process_function_map['default']) 162 | print(f"- process prompting by {process_func.__name__}") 163 | 164 | for s_id, conv in org_data.items(): 165 | process_args = [data_name, conv, around_window, s_id, desc_speaker_data] 166 | samples = process_func(*process_args, **process_kwargs) 167 | new_format = new_format + samples 168 | 169 | with open(f'{path_processed_data}', 'wt') as f: 170 | new_format = [json.dumps(e) for e in new_format] 171 | f.write("\n".join(new_format)) 172 | 173 | 174 | # if __name__=="__main__": 175 | # process('train', around_window=5, use_spdesc=True) 176 | # process('test', around_window=5, use_spdesc=True) 177 | # process('valid', around_window=5, use_spdesc=True) 178 | 179 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## BiosERC: Integrating Biography Speakers Supported by LLMs for ERC Tasks 2 | In the Emotion Recognition in Conversation task, recent investigations have utilized attention mechanisms exploring relationships among utterances from intra- and inter-speakers for modeling emotional interaction between them. However, attributes such as speaker personality traits remain unexplored and present challenges in terms of their applicability to other tasks or compatibility with diverse model architectures. Therefore, this work introduces a novel framework named BiosERC, which investigates speaker characteristics in a conversation. By employing Large Language Models (LLMs), we extract the ``biographical information'' of the speaker within a conversation as supplementary knowledge injected into the model to classify emotional labels for each utterance. 
Our proposed method achieved state-of-the-art (SOTA) results on three famous benchmark datasets: IEMOCAP, MELD, and EmoryNLP, demonstrating the effectiveness and generalization of our model and showcasing its potential for adaptation to various conversation analysis tasks. 3 | 4 | Full paper here: [https://link.springer.com/chapter/10.1007/978-3-031-72344-5_19](https://link.springer.com/chapter/10.1007/978-3-031-72344-5_19) 5 | 6 | ## Results 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bioserc-integrating-biography-speakers/emotion-recognition-in-conversation-on-meld)](https://paperswithcode.com/sota/emotion-recognition-in-conversation-on-meld?p=bioserc-integrating-biography-speakers)
8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bioserc-integrating-biography-speakers/emotion-recognition-in-conversation-on-4)](https://paperswithcode.com/sota/emotion-recognition-in-conversation-on-4?p=bioserc-integrating-biography-speakers)
9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bioserc-integrating-biography-speakers/emotion-recognition-in-conversation-on)](https://paperswithcode.com/sota/emotion-recognition-in-conversation-on?p=bioserc-integrating-biography-speakers) 10 | 11 | Performance comparison between our proposed method and previous works on the test sets. 12 | | **Methods** | | **IEMOCAP** | **EmoryNLP** | **MELD** | 13 | | :--------------------------------------------- | :---: | :---------: | :----------: | :-------: | 14 | | | | | | | 15 | | HiTrans | | 64.50 | 36.75 | 61.94 | 16 | | DAG | | 68.03 | 39.02 | 63.65 | 17 | | DialogXL | | 65.94 | 34.73 | 62.14 | 18 | | DialogueEIN | | 68.93 | 38.92 | 65.37 | 19 | | SGED + DAG-ERC | | 68.53 | 40.24 | 65.46 | 20 | | S+PAGE | | 68.93 | 40.05 | 64.67 | 21 | | InstructERC _+(ft LLM)_ | | **71.39** | 41.39 | 69.15 | 22 | | | | | | | 23 | | Intra/inter ERC (baseline) ${[AccWR]}_{MLP}$ | | 67.65 | 39.33 | 64.58 | 24 | | ${BiosERC}_{BERT-based}$ | | 67.79 | 39.89 | 65.51 | 25 | | ${BiosERC + LoRA}_{Llama-2-7b}$ | | 69.02 | 41.44 | 68.72 | 26 | | ${BiosERC + LoRA}_{Llama-2-13b}$ | | 71.19 | **41.68** | **69.83** | 27 | | | | | | | 28 | 29 | ## Data 30 | Unzip the file `data.zip` to extract the data. 31 | - IEMOCAP 32 | Data structure example: 33 | ```json 34 | { 35 | # this is the first conversation 36 | "Ses05M_impro03": { 37 | "labels": [ 38 | 4, 39 | 2, 40 | 4, 41 | 4 42 | ], 43 | "sentences": [ 44 | "Guess what?", 45 | "what?", 46 | "I did it, I asked her to marry me.", 47 | "Yes, I did it." 48 | ], 49 | "genders": [ 50 | "M", 51 | "F", 52 | "M", 53 | "M" 54 | 55 | ] 56 | }, 57 | 58 | # this is the second conversation 59 | "Ses05M_impro04": { 60 | "labels": [ 61 | 4, 62 | 2 63 | ], 64 | "sentences": [ 65 | "Guess what?", 66 | "what?" 67 | ], 68 | "genders": [ 69 | "M", 70 | "F" 71 | ] 72 | } 73 | } 74 | ``` 75 | 76 | ## Python ENV 77 | Initialize the Python environment: 78 | ```cmd 79 | conda create --prefix=./env_py38 python=3.9 80 | conda activate ./env_py38 81 | pip install -r requirements.txt 82 | ``` 83 | 84 | ## Run 85 | 1. Initialize the environment following the step above. 86 | 2. Data preprocessing. 87 | 1. Put all the raw data into the folder `data/`. 88 | An overview of the data structure: 89 | ``` 90 | . 91 | ├── data/ 92 | │   ├── meld.valid_spdescV2_Llama-2-70b-chat-hf.json # speaker biography will be generated by running `python src/llm_bio_extract.py` 93 | │   ├── meld.train_spdescV2_Llama-2-70b-chat-hf.json # speaker biography will be generated by running `python src/llm_bio_extract.py` 94 | │   ├── meld.test_spdescV2_Llama-2-70b-chat-hf.json # speaker biography will be generated by running `python src/llm_bio_extract.py` 95 | │   ├── meld.test.json 96 | │   ├── meld.train.json 97 | │   ├── meld.valid.json 98 | │   ├── ... 99 | │   ├── iemocap.test.json 100 | │   ├── iemocap.train.json 101 | │   └── iemocap.valid.json 102 | ├── src/ 103 | ├── finetuned_llm/ 104 | └── ... 105 | ``` 106 | 3. Train 107 | Run the following commands to train a new model. 108 | ```bash 109 | python src/llm_bio_extract.py # to extract speaker biographies 110 | bash scripts/train_llm.sh # to train an LLM model 111 | ``` 112 | > **Note**: Please check this script to verify the settings and choose which dataset you want to run. 113 | 114 | 4. 
Run inference with a trained model from Hugging Face 115 | Download the model and data from [https://huggingface.co/phuongnm94/BiosERC](https://huggingface.co/phuongnm94/BiosERC) 116 | ```bash 117 | bash scripts/infer.sh # to run inference with the LLM-based BiosERC (Llama-2-13b) 118 | ``` 119 | or run the notebook 120 | ```bash 121 | scripts/infer_bioserc_bertbased.ipynb # to run inference with the BERT-based BiosERC model 122 | ``` 123 | > **Note**: Please check all the related data and model paths. 124 | ## Citation 125 | 126 | ```bibtex 127 | @InProceedings{10.1007/978-3-031-72344-5_19, 128 | author="Xue, Jieying 129 | and Nguyen, Minh-Phuong 130 | and Matheny, Blake 131 | and Nguyen, Le-Minh", 132 | editor="Wand, Michael 133 | and Malinovsk{\'a}, Krist{\'i}na 134 | and Schmidhuber, J{\"u}rgen 135 | and Tetko, Igor V.", 136 | title="BiosERC: Integrating Biography Speakers Supported by LLMs for ERC Tasks", 137 | booktitle="Artificial Neural Networks and Machine Learning -- ICANN 2024", 138 | year="2024", 139 | publisher="Springer Nature Switzerland", 140 | address="Cham", 141 | pages="277--292", 142 | abstract="In the Emotion Recognition in Conversation task, recent investigations have utilized attention mechanisms exploring relationships among utterances from intra- and inter-speakers for modeling emotional interaction between them. However, attributes such as speaker personality traits remain unexplored and present challenges in terms of their applicability to other tasks or compatibility with diverse model architectures. Therefore, this work introduces a novel framework named BiosERC, which investigates speaker characteristics in a conversation. By employing Large Language Models (LLMs), we extract the ``biographical information'' of the speaker within a conversation as supplementary knowledge injected into the model to classify emotional labels for each utterance. Our proposed method achieved state-of-the-art (SOTA) results on three famous benchmark datasets: IEMOCAP, MELD, and EmoryNLP, demonstrating the effectiveness and generalization of our model and showcasing its potential for adaptation to various conversation analysis tasks. Our source code is available at https://github.com/yingjie7/BiosERC.", 143 | isbn="978-3-031-72344-5" 144 | } 145 | 146 | ``` 147 | 148 | ## Licensing Information 149 | - **MELD**: Licensed under [GPL-3.0](https://www.gnu.org/licenses/gpl-3.0.en.html). 150 | - **EMORYNLP**: Licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). 151 | - **IEMOCAP**: Licensed under a non-commercial research license. Refer to the official [IEMOCAP website](https://sail.usc.edu/iemocap/) for terms of use. 
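The `spdescV2` prompting type additionally requires the speaker-biography files listed in the data tree above (`<data_name>.<split>_spdescV2_<llm_id>.json`). Their actual contents are produced by `src/llm_bio_extract.py`; the sketch below only shows the layout that `src/reformat_data_ft_llm.py` assumes when it indexes `desc_speaker_data[s_id][i]`: one description string per utterance, keyed by conversation id. The description texts here are invented placeholders.

```python
# Assumed layout of e.g. data/iemocap.train_spdescV2_Llama-2-70b-chat-hf.json,
# inferred from how gen_spdescV2_prompting_messages() reads it.
# The description strings are placeholders, not real model output.
speaker_descriptions = {
    "Ses05M_impro03": [
        "JAMES is excited and eager to share important personal news ...",   # description for utterance 0
        "MARY responds with curiosity to her partner's announcement ...",    # description for utterance 1
    ],
}
```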
152 | -------------------------------------------------------------------------------- /src/ft_llm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from datasets import load_dataset 4 | import torch 5 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments 6 | from trl import setup_chat_format, set_seed as trl_seed 7 | from peft import LoraConfig, AutoPeftModelForCausalLM 8 | from trl import SFTTrainer 9 | from transformers import set_seed as transf_seed 10 | from tqdm import tqdm 11 | from sklearn.metrics import f1_score 12 | import numpy as np 13 | from torch.utils.data import DataLoader 14 | from transformers.trainer_utils import EvalLoopOutput 15 | import random, glob 16 | from lightning import seed_everything 17 | 18 | from reformat_data_ft_llm import process 19 | 20 | def set_random_seed(seed: int): 21 | """set seeds for reproducibility""" 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | seed_everything(seed=seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | trl_seed(seed) 30 | transf_seed(seed) 31 | 32 | 33 | def formatting_prompts_func(samples): 34 | prompt_texts = [tokenizer.apply_chat_template( 35 | sample[:-1], tokenize=False, add_generation_prompt=True) for sample in samples["messages"]] 36 | 37 | print("=="*50) 38 | print(prompt_texts[-1]) 39 | print("=="*50) 40 | return prompt_texts 41 | 42 | def split_label(sample): 43 | tokenized_lb = tokenizer.encode(sample['messages'][-1]['content'], padding='max_length',max_length=10 ) 44 | sample['labels'] = tokenized_lb 45 | return sample 46 | 47 | class LLMErcTrainer(SFTTrainer): 48 | def __init__(self, *args, **kwargs): 49 | super().__init__(*args, **kwargs) 50 | 51 | self.data_process_args = argparse.Namespace( 52 | packing=False, 53 | dataset_text_field=None, 54 | max_seq_length=kwargs.get('max_seq_length', None), 55 | formatting_func=formatting_prompts_func, 56 | num_of_sequences=kwargs.get('num_of_sequences', 1024), 57 | chars_per_token=kwargs.get('chars_per_token', 3.6), 58 | remove_unused_columns=kwargs.get('args').remove_unused_columns if kwargs.get('args') is not None else True, 59 | dataset_kwargs=kwargs.get('dataset_kwargs', {}) 60 | ) 61 | self.eval_dataset = self._process_raw_data(kwargs.get('eval_dataset', None)) 62 | print("len(eval dataset) = ", len(self.eval_dataset)) 63 | 64 | def _process_raw_data(self, dataset): 65 | dataset2 = dataset.map(split_label) 66 | dataset = self._prepare_dataset( 67 | dataset=dataset, 68 | tokenizer=self.tokenizer, 69 | packing=False, 70 | dataset_text_field=None, 71 | max_seq_length=self.data_process_args.max_seq_length, 72 | formatting_func=self.data_process_args.formatting_func, 73 | num_of_sequences=self.data_process_args.num_of_sequences, 74 | chars_per_token=self.data_process_args.chars_per_token, 75 | remove_unused_columns=self.data_process_args.remove_unused_columns, 76 | **self.data_process_args.dataset_kwargs, 77 | ) 78 | dataset = dataset.add_column('labels', dataset2['labels']) 79 | return dataset 80 | 81 | def get_eval_dataloader(self, eval_dataset=None) -> DataLoader: 82 | if "input_ids" not in eval_dataset.column_names and "labels" not in eval_dataset.column_names: 83 | # this is raw data which need to preprocess 84 | eval_dataset = self._process_raw_data(eval_dataset) 85 | 86 | return super().get_eval_dataloader(eval_dataset) 87 | 
88 | def evaluation_loop( 89 | self, 90 | dataloader: DataLoader, 91 | description: str, 92 | prediction_loss_only= None, 93 | ignore_keys = None, 94 | metric_key_prefix="eval", 95 | ) -> EvalLoopOutput: 96 | """ 97 | Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. 98 | 99 | Works both with or without labels. 100 | """ 101 | model = self.model 102 | model = model.to(dtype=torch.bfloat16) 103 | 104 | model.eval() 105 | 106 | # losses/preds/labels on CPU (final containers) 107 | all_preds = [] 108 | all_labels = [] 109 | all_raw_decoded = [] 110 | 111 | def post_process(str_out): 112 | try: 113 | gen_text = str_out.split("assistant\n")[-1].split("<|im_end|>")[0] 114 | except: 115 | gen_text = "error" 116 | return gen_text 117 | 118 | # Main evaluation loop 119 | with torch.no_grad(): 120 | for step, inputs in enumerate(tqdm(dataloader)): 121 | inputs = self._prepare_inputs(inputs) 122 | gen_kwargs = {'max_new_tokens': 10, 123 | 'do_sample': False, 124 | 'eos_token_id': self.tokenizer.eos_token_id, 125 | 'pad_token_id': self.tokenizer.pad_token_id, 126 | "temperature": 0.1, 127 | } 128 | generated_tokens = model.generate( 129 | inputs["input_ids"], 130 | attention_mask=inputs["attention_mask"], 131 | **gen_kwargs, 132 | ) 133 | labels = inputs.pop("labels") 134 | str_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 135 | 136 | raw_decoded = [e for e in self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)] 137 | str_decoded = [post_process(e) for e in raw_decoded] 138 | all_preds += str_decoded 139 | all_labels += str_labels 140 | all_raw_decoded += raw_decoded 141 | num_samples = len(dataloader) 142 | 143 | f1_weighted = f1_score( 144 | all_labels, 145 | all_preds, 146 | average=f"weighted", 147 | ) 148 | metrics = { f"{metric_key_prefix}_weighted-f1": f1_weighted } 149 | 150 | json.dump({"metrics": metrics, 151 | "detail_pred": list(zip(all_preds, all_labels, all_raw_decoded))}, 152 | open(f"{self.args.output_dir}/result_{metric_key_prefix}_step-{self.state.global_step}.json", "wt"), indent=1) 153 | 154 | # free the memory again 155 | del model 156 | torch.cuda.empty_cache() 157 | return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) 158 | 159 | 160 | if __name__=='__main__': 161 | 162 | parser = argparse.ArgumentParser(description='Process ...') 163 | parser.add_argument('--do_train', action="store_true", help='fine tuning a LLM model with LoRA', default=False) 164 | parser.add_argument('--do_eval_test', action="store_true", help='eval on test set', default=False) 165 | parser.add_argument('--do_eval_dev', action="store_true", help='eval on dev set', default=False) 166 | parser.add_argument('--ft_model_path', type=str, default=None, help='fintuned model path') 167 | parser.add_argument('--ft_model_id', type=str, default=None, help='fintuned model id for saving after train it') 168 | parser.add_argument('--prompting_type', type=str, default='spdescV2', help='prompting style in {cot, fewshot, zeroshot}') 169 | parser.add_argument('--base_model_id', type=str, default='meta-llama/Llama-2-7b-hf', help='base llm model id') 170 | parser.add_argument('--extract_prompting_llm_id', type=str, default='Llama-2-7b-chat-hf', help='base llm model id') 171 | parser.add_argument('--epoch', type=int, default=None, help='training epoch') 172 | parser.add_argument('--max_steps', type=int, default=None, help='training steps') 173 | parser.add_argument('--lr_scheduler', type=str, 
default='constant', help='learning rate scheduler') 174 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate value') 175 | parser.add_argument('--seed', type=int, default=42, help='random seed value') 176 | parser.add_argument('--kshot', type=int, default=0, help='k shot examples for llm') 177 | parser.add_argument('--lora_r', type=int, default=32, help='lora rank') 178 | parser.add_argument('--eval_delay', type=int, default=200, help='eval delay') 179 | parser.add_argument('--window', type=int, default=5, help='local context window size') 180 | parser.add_argument('--max_seq_len', type=int, default=None, help='max sequence length for chunking/packing') 181 | parser.add_argument('--re_gen_data', action="store_true", help='re generate data', default=False) 182 | parser.add_argument('--data_name', type=str, help='data name in {iemocap, meld, emorynlp}', default='iemocap') 183 | parser.add_argument('--data_folder', type=str, help='path folder save all data', default='./data/') 184 | parser.add_argument('--output_folder', type=str, help='path folder save all data', default='./finetuned_llm/') 185 | 186 | args, unknown = parser.parse_known_args() 187 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 188 | if args.prompting_type == 'zeroshot': 189 | args.kshot = 0 190 | print(args) 191 | 192 | set_random_seed(args.seed) 193 | 194 | all_path_folder_preprocessed_data = [f"{args.data_folder}/{args.data_name}.{d_type}.{args.kshot}shot_w{args.window}_{args.prompting_type}.jsonl" \ 195 | for d_type in [ 'train' , 'valid', 'test']] 196 | if args.re_gen_data: 197 | process(all_path_folder_preprocessed_data, args) 198 | 199 | # Load jsonl data from disk 200 | dataset = load_dataset("json", data_files=all_path_folder_preprocessed_data[0], split="train", cache_dir=f'{args.output_folder}/{args.ft_model_id}') 201 | valid_dataset = load_dataset("json", data_files=all_path_folder_preprocessed_data[1], split="train", cache_dir=f'{args.output_folder}/{args.ft_model_id}') 202 | test_dataset = load_dataset("json", data_files=all_path_folder_preprocessed_data[2], split="train", cache_dir=f'{args.output_folder}/{args.ft_model_id}') 203 | 204 | 205 | # Load model and tokenizer 206 | model_id = args.base_model_id # "codellama/CodeLlama-7b-hf" # or `mistralai/Mistral-7B-v0.1` 207 | if args.do_train: 208 | tensor_data_type = torch.bfloat16 209 | bnb_config = BitsAndBytesConfig( 210 | load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=tensor_data_type 211 | ) 212 | tokenizer = AutoTokenizer.from_pretrained(model_id) 213 | if args.ft_model_path is not None: 214 | model = AutoPeftModelForCausalLM.from_pretrained( 215 | args.ft_model_path, 216 | device_map="auto", 217 | torch_dtype=tensor_data_type, 218 | load_in_8bit=True 219 | ) 220 | else: 221 | model = AutoModelForCausalLM.from_pretrained( 222 | model_id, 223 | device_map="auto", 224 | attn_implementation="flash_attention_2", 225 | torch_dtype=tensor_data_type, 226 | quantization_config=bnb_config 227 | ) 228 | else: 229 | tensor_data_type = torch.float32 # for reduce the miss matching of ouputs of batch inference 230 | ft_model_path = f"{args.output_folder}/{args.ft_model_id}" if args.ft_model_path is None else args.ft_model_path 231 | tokenizer = AutoTokenizer.from_pretrained(ft_model_path) 232 | model = AutoPeftModelForCausalLM.from_pretrained( 233 | ft_model_path, 234 | device_map="auto", 235 | torch_dtype=tensor_data_type 236 | ) 237 | 238 | 239 | # tokenizer = 
AutoTokenizer.from_pretrained(model_id) 240 | tokenizer.padding_side = 'left' 241 | 242 | # # set chat template to OAI chatML, remove if you start from a fine-tuned model 243 | model, tokenizer = setup_chat_format(model, tokenizer) 244 | 245 | # training config 246 | # LoRA config based on QLoRA paper & Sebastian Raschka experiment 247 | peft_config = LoraConfig( 248 | lora_alpha=128, 249 | lora_dropout=0.05, 250 | r=args.lora_r, 251 | bias="none", 252 | target_modules="all-linear", 253 | task_type="CAUSAL_LM", 254 | ) 255 | 256 | training_args = TrainingArguments( 257 | output_dir=f'{args.output_folder}/{args.ft_model_id}', # directory to save and repository id 258 | num_train_epochs= args.epoch, # number of training epochs 259 | max_steps=args.max_steps, 260 | per_device_train_batch_size=4, # batch size per device during training 261 | per_device_eval_batch_size=1, 262 | gradient_accumulation_steps=4, # number of steps before performing a backward/update pass 263 | gradient_checkpointing=True, # use gradient checkpointing to save memory 264 | save_total_limit=1, 265 | optim="adamw_torch_fused", # use fused adamw optimizer 266 | eval_delay=args.eval_delay, # log every 10 steps meld:200 267 | logging_steps=50, # log every 10 steps 268 | eval_steps=50, 269 | save_steps=50, 270 | load_best_model_at_end=True, 271 | metric_for_best_model='weighted-f1', 272 | greater_is_better=True, 273 | evaluation_strategy='steps', 274 | save_strategy="steps", # save checkpoint every epoch 275 | learning_rate=args.lr, # learning rate, based on QLoRA paper 276 | bf16=True, # use bfloat16 precision 277 | tf32=True, # use tf32 precision 278 | max_grad_norm=0.3, # max gradient norm based on QLoRA paper 279 | warmup_ratio=0.03, # warmup ratio based on QLoRA paper 280 | lr_scheduler_type=args.lr_scheduler, # use constant learning rate scheduler 281 | push_to_hub=False, # push model to hub ########################## 282 | group_by_length=True, 283 | report_to="tensorboard", # report metrics to tensorboard 284 | ) 285 | 286 | trainer = LLMErcTrainer( 287 | model=model, 288 | args=training_args, 289 | train_dataset=dataset, 290 | eval_dataset=valid_dataset, 291 | neftune_noise_alpha=5, 292 | peft_config=peft_config, 293 | max_seq_length=args.max_seq_len, 294 | tokenizer=tokenizer, 295 | packing=True, 296 | dataset_kwargs={ 297 | "add_special_tokens": False, # We template with special tokens 298 | "append_concat_token": False, # No need to add additional separator token 299 | } 300 | ) 301 | 302 | # n_trainable_pr, total_pr = get_peft_model(model, peft_config).get_nb_trainable_parameters() 303 | # print(f"total params: {n_trainable_pr}, trainable params {total_pr}, percentage={n_trainable_pr/total_pr*100}") 304 | 305 | if args.do_train: 306 | # start training, the model will be automatically saved to the hub and the output directory 307 | # trainer.train(ignore_keys_for_eval=[]) 308 | trainer.train(resume_from_checkpoint=True if len(glob.glob(f'{args.output_folder}/{args.ft_model_id}/checkpoint-*')) > 0 else None) 309 | 310 | # save model 311 | trainer.save_model() 312 | 313 | 314 | ft_model_path = f'{args.output_folder}/{args.ft_model_id}' if args.ft_model_path is None else args.ft_model_path 315 | 316 | if args.do_eval_test: 317 | result = trainer.evaluate(test_dataset, metric_key_prefix='test') 318 | print(f"Test result = {result}") 319 | 320 | if args.do_eval_dev: 321 | result = trainer.evaluate(valid_dataset, metric_key_prefix='valid') 322 | print(f"Valid result = {result}") 323 | 
-------------------------------------------------------------------------------- /scripts/infer_bioserc_bertbased.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json, argparse, sys\n", 10 | "import pandas as pd\n", 11 | "sys.path.append(\"../src\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 6, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "config_path ../finetuned_llm/bioserc_bert_based/version_13/hparams.yaml\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import torch\n", 29 | "from baseline.src.main import *\n", 30 | "import yaml\n", 31 | "\n", 32 | "type_data = 'valid'\n", 33 | "\n", 34 | "config_path = '../finetuned_llm/bioserc_bert_based/version_13/hparams.yaml'\n", 35 | "model_path = \"../finetuned_llm/bioserc_bert_based/roberta-large-meld-valid/f1=67.41.ckpt\"\n", 36 | "\n", 37 | "print('config_path',config_path)\n", 38 | "\n", 39 | "with open(config_path, \"r\") as yamlfile:\n", 40 | " model_configs = argparse.Namespace(**yaml.load(yamlfile, Loader=yaml.FullLoader))\n", 41 | " \n", 42 | "model_configs.data_folder = '../data/'\n", 43 | "# model_configs.window_ct = 2\n", 44 | "# model_configs.speaker_description = False\n", 45 | "# model_configs.llm_context = False\n", 46 | "# model_configs.data_name_pattern = \"meld.{}.json\"\n", 47 | "dataset_name = model_configs.data_name_pattern.split(\".\")[0]\n" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 7, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "meld\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "print(dataset_name)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 8, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "\n", 74 | "\n", 75 | "# meld\n", 76 | "label2id = {'neutral': 0, 'surprise': 1, 'fear': 2, 'sadness': 3, 'joy': 4, 'disgust': 5, 'anger':6}\n", 77 | "id2label = ['neutral', 'surprise', 'fear', 'sadness', 'joy', 'disgust', 'anger']\n", 78 | "\n", 79 | "\n", 80 | "# # emorynlp\n", 81 | "# label2id = {'Joyful': 0, 'Mad': 1, 'Peaceful': 2, 'Neutral': 3, 'Sad': 4, 'Powerful': 5, 'Scared': 6}\n", 82 | "# id2label = ['Joyful', 'Mad', 'Peaceful', 'Neutral', 'Sad', 'Powerful', 'Scared']\n", 83 | "# # iemocap\n", 84 | "# label2id = {'hap':0, 'sad':1, 'neu':2, 'ang':3, 'exc':4, 'fru':5}\n", 85 | "# id2label = ['hap', 'sad', 'neu', 'ang', 'exc', 'fru']\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 9, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "\n", 95 | "from torch.utils.data import DataLoader\n", 96 | "from transformers import AutoTokenizer, AutoModel\n", 97 | "import json\n", 98 | "import random\n", 99 | "import argparse\n", 100 | "from sklearn.metrics import f1_score\n", 101 | "\n", 102 | "import pytorch_lightning as pl\n", 103 | "from pytorch_lightning import Trainer\n", 104 | "\n", 105 | "bert_tokenizer = AutoTokenizer.from_pretrained(model_configs.pre_trained_model_name)\n", 106 | "\n", 107 | "data_loader_valid = BatchPreprocessor(bert_tokenizer, model_configs=model_configs, data_type=type_data)\n", 108 | "raw_data = BatchPreprocessor.load_raw_data(f\"{model_configs.data_folder}/{model_configs.data_name_pattern.format(type_data)}\")\n", 109 | "valid_loader = 
DataLoader(raw_data,\n", 110 | " batch_size=model_configs.batch_size, collate_fn=data_loader_valid, shuffle=False)\n", 111 | "\n", 112 | "data_loader_test=BatchPreprocessor(bert_tokenizer, model_configs=model_configs, data_type='test')\n", 113 | "raw_data_test = BatchPreprocessor.load_raw_data(f\"{model_configs.data_folder}/{model_configs.data_name_pattern.format('test')}\")\n", 114 | "test_loader = DataLoader(raw_data_test,\n", 115 | " batch_size=model_configs.batch_size, collate_fn=data_loader_test, shuffle=False)\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 10, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stderr", 125 | "output_type": "stream", 126 | "text": [ 127 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", 128 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "model_configs.spdesc_aggregate_method = 'static'\n", 134 | "# model_configs.llm_context = False\n", 135 | "# model_configs.speaker_description=False\n", 136 | "model = EmotionClassifier(model_configs) \n", 137 | "# model.model_configs = model_configs" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 11, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stderr", 147 | "output_type": "stream", 148 | "text": [ 149 | "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", 150 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", 151 | "GPU available: True (cuda), used: True\n", 152 | "TPU available: False, using: 0 TPU cores\n", 153 | "IPU available: False, using: 0 IPUs\n", 154 | "HPU available: False, using: 0 HPUs\n", 155 | "You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 156 | "Missing logger folder: /home/phuongnm/BiosERC/scripts/lightning_logs\n", 157 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-534884ef-a89c-23e9-885a-13ca043a4659,GPU-e2bff860-9659-2d80-4f5d-b02e78499047]\n", 158 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=51` in the `DataLoader` to improve performance.\n" 159 | ] 160 | }, 161 | { 162 | "data": { 163 | "application/vnd.jupyter.widget-view+json": { 164 | "model_id": "eddbf947e2da4788898d3e3944bd2045", 165 | "version_major": 2, 166 | "version_minor": 0 167 | }, 168 | "text/plain": [ 169 | "Testing: | …" 170 | ] 171 | }, 172 | "metadata": {}, 173 | "output_type": "display_data" 174 | }, 175 | { 176 | "name": "stderr", 177 | "output_type": "stream", 178 | "text": [ 179 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/torch/nn/modules/transformer.py:384: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:177.)\n", 180 | " output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)\n", 181 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 182 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 23. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 183 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 184 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 6. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 185 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 18. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 186 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 15. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 187 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 12. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 188 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 14. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 189 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 17. 
To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 190 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 10. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 191 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 7. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 192 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 27. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 193 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 11. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 194 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 9. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 195 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 196 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 26. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 197 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 19. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 198 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 13. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 199 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 21. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 200 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 201 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 5. 
To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 202 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 20. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 203 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 22. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 204 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 28. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n" 205 | ] 206 | }, 207 | { 208 | "data": { 209 | "text/html": [ 210 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
211 |        "┃        Test metric               DataLoader 0        ┃\n",
212 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
213 |        "│         hp_metric              67.41372680664062     │\n",
214 |        "│          test/f1               67.41372680664062     │\n",
215 |        "│        train/loss             1.7146971225738525     │\n",
216 |        "│        valid/loss             1.7146971225738525     │\n",
217 |        "└───────────────────────────┴───────────────────────────┘\n",
218 |        "
\n" 219 | ], 220 | "text/plain": [ 221 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 222 | "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", 223 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 224 | "│\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 67.41372680664062 \u001b[0m\u001b[35m \u001b[0m│\n", 225 | "│\u001b[36m \u001b[0m\u001b[36m test/f1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 67.41372680664062 \u001b[0m\u001b[35m \u001b[0m│\n", 226 | "│\u001b[36m \u001b[0m\u001b[36m train/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.7146971225738525 \u001b[0m\u001b[35m \u001b[0m│\n", 227 | "│\u001b[36m \u001b[0m\u001b[36m valid/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.7146971225738525 \u001b[0m\u001b[35m \u001b[0m│\n", 228 | "└───────────────────────────┴───────────────────────────┘\n" 229 | ] 230 | }, 231 | "metadata": {}, 232 | "output_type": "display_data" 233 | }, 234 | { 235 | "name": "stderr", 236 | "output_type": "stream", 237 | "text": [ 238 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-534884ef-a89c-23e9-885a-13ca043a4659,GPU-e2bff860-9659-2d80-4f5d-b02e78499047]\n" 239 | ] 240 | }, 241 | { 242 | "name": "stdout", 243 | "output_type": "stream", 244 | "text": [ 245 | "[{'train/loss': 1.7146971225738525, 'valid/loss': 1.7146971225738525, 'test/f1': 67.41372680664062, 'hp_metric': 67.41372680664062}]\n" 246 | ] 247 | }, 248 | { 249 | "name": "stderr", 250 | "output_type": "stream", 251 | "text": [ 252 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=51` in the `DataLoader` to improve performance.\n" 253 | ] 254 | }, 255 | { 256 | "data": { 257 | "application/vnd.jupyter.widget-view+json": { 258 | "model_id": "5baad8c0f13441d19e1d6d1afeed4d9e", 259 | "version_major": 2, 260 | "version_minor": 0 261 | }, 262 | "text/plain": [ 263 | "Testing: | …" 264 | ] 265 | }, 266 | "metadata": {}, 267 | "output_type": "display_data" 268 | }, 269 | { 270 | "name": "stderr", 271 | "output_type": "stream", 272 | "text": [ 273 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 24. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 274 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 40. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 275 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 25. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", 276 | "/home/phuongnm/py_envs/python/env_llm/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 3. 
To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n" 277 | ] 278 | }, 279 | { 280 | "data": { 281 | "text/html": [ 282 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
283 |        "┃        Test metric               DataLoader 0        ┃\n",
284 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
285 |        "│         hp_metric              65.95140075683594     │\n",
286 |        "│          test/f1               65.95140075683594     │\n",
287 |        "│        train/loss             1.7551467418670654     │\n",
288 |        "│        valid/loss             1.7551467418670654     │\n",
289 |        "└───────────────────────────┴───────────────────────────┘\n",
290 |        "
\n" 291 | ], 292 | "text/plain": [ 293 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", 294 | "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", 295 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", 296 | "│\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 65.95140075683594 \u001b[0m\u001b[35m \u001b[0m│\n", 297 | "│\u001b[36m \u001b[0m\u001b[36m test/f1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 65.95140075683594 \u001b[0m\u001b[35m \u001b[0m│\n", 298 | "│\u001b[36m \u001b[0m\u001b[36m train/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.7551467418670654 \u001b[0m\u001b[35m \u001b[0m│\n", 299 | "│\u001b[36m \u001b[0m\u001b[36m valid/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.7551467418670654 \u001b[0m\u001b[35m \u001b[0m│\n", 300 | "└───────────────────────────┴───────────────────────────┘\n" 301 | ] 302 | }, 303 | "metadata": {}, 304 | "output_type": "display_data" 305 | }, 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "[{'train/loss': 1.7551467418670654, 'valid/loss': 1.7551467418670654, 'test/f1': 65.95140075683594, 'hp_metric': 65.95140075683594}]\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "\n", 316 | "import json \n", 317 | "import itertools \n", 318 | "model = EmotionClassifier.load_from_checkpoint(model_path, strict=False, model_configs=model_configs)\n", 319 | "trainer = Trainer(max_epochs=1, accelerator=\"gpu\", devices=1, )\n", 320 | "\n", 321 | "print(trainer.test(model, valid_loader))\n", 322 | "print(trainer.test(model, test_loader))\n" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [] 331 | } 332 | ], 333 | "metadata": { 334 | "kernelspec": { 335 | "display_name": "env_llm", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | "version": 3 343 | }, 344 | "file_extension": ".py", 345 | "mimetype": "text/x-python", 346 | "name": "python", 347 | "nbconvert_exporter": "python", 348 | "pygments_lexer": "ipython3", 349 | "version": "3.9.18" 350 | }, 351 | "orig_nbformat": 4 352 | }, 353 | "nbformat": 4, 354 | "nbformat_minor": 2 355 | } 356 | -------------------------------------------------------------------------------- /src/llm_bio_extract.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | from torch.utils.data import DataLoader 5 | import traceback 6 | from tqdm import tqdm 7 | import json 8 | from transformers import LlamaTokenizer, AutoModel, AutoTokenizer, LlamaForCausalLM 9 | 10 | import torch 11 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, AutoConfig 12 | 13 | 14 | dataset_name = 'iemocap' 15 | data_folder = './data/' 16 | prompt_type = 'spdescV2' 17 | 18 | 19 | print("Loading model ...") 20 | # trained with chat and instruction 21 | model_name = 'meta-llama/Llama-2-70b-chat-hf' 22 | # model_name = 'meta-llama/Meta-Llama-3-8B-Instruct' # standard model 23 | tensor_data_type = torch.bfloat16 24 | bnb_config = BitsAndBytesConfig( 25 | load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=tensor_data_type 26 | ) 27 | model = 
LlamaForCausalLM.from_pretrained( 28 | model_name, 29 | # return_dict=True, 30 | quantization_config=bnb_config,  # load with the 4-bit NF4 config defined above 31 | device_map="auto", 32 | # low_cpu_mem_usage=True, 33 | ) 34 | 35 | tokenizer = AutoTokenizer.from_pretrained(model_name) 36 | tokenizer.pad_token = tokenizer.eos_token 37 | 38 | 39 | class BatchPreprocessor(object): 40 | def __init__(self, tokenizer, dataset_name=None, window_ct=2) -> None: 41 | self.tokenizer = tokenizer 42 | self.separate_token_id = self.tokenizer.convert_tokens_to_ids("
") 43 | self.dataset_name = dataset_name 44 | self.window_ct = window_ct 45 | 46 | @staticmethod 47 | def load_raw_data(path_data): 48 | raw_data = json.load(open(path_data)) 49 | if isinstance(raw_data, dict): 50 | new_data_list = [] 51 | for k, v in raw_data.items(): 52 | v['s_id'] = k 53 | new_data_list.append(v) 54 | return new_data_list 55 | elif isinstance(raw_data, list): 56 | return raw_data 57 | 58 | 59 | @staticmethod 60 | def get_speaker_name(s_id, gender, data_name): 61 | if data_name == "iemocap": 62 | # iemocap: label index mapping = {'hap':0, 'sad':1, 'neu':2, 'ang':3, 'exc':4, 'fru':5} 63 | speaker = { 64 | "Ses01": {"F": "Mary", "M": "James"}, 65 | "Ses02": {"F": "Patricia", "M": "John"}, 66 | "Ses03": {"F": "Jennifer", "M": "Robert"}, 67 | "Ses04": {"F": "Linda", "M": "Michael"}, 68 | "Ses05": {"F": "Elizabeth", "M": "William"}, 69 | } 70 | s_id_first_part = s_id[:5] 71 | return speaker[s_id_first_part][gender].upper() 72 | elif data_name in ['meld', "emorynlp"]: 73 | # emorynlp: label index mapping = {'Joyful': 0, 'Mad': 1, 'Peaceful': 2, 'Neutral': 3, 'Sad': 4, 'Powerful': 5, 'Scared': 6} 74 | # meld: label index mapping = {'neutral': 0, 'surprise': 1, 'fear': 2, 'sadness': 3, 'joy': 4, 'disgust': 5, 'anger':6} 75 | gender_idx = gender.index(1) 76 | return f"SPEAKER_{gender_idx}" 77 | elif data_name=='dailydialog': 78 | # dailydialog: {'no_emotion': 0, 'happiness': 1, 'sadness': 2, 'surprise': 3, 'anger': 4, 'fear': 5, 'disgust':6} 79 | return f"SPEAKER_{gender}" 80 | 81 | def sentence_mixed_by_surrounding(self, sentences, around_window, s_id, genders, data_name): 82 | new_sentences = [] 83 | for i, cur_sent in enumerate(sentences): 84 | tmp_s = "" 85 | for j in range(max(0, i-around_window), min(len(sentences), i+around_window+1)): 86 | if i == j: 87 | tmp_s += " " 88 | tmp_s += f" {self.get_speaker_name(s_id, genders[j], data_name=data_name)}: {sentences[j]}" 89 | if i == j: 90 | tmp_s += " " 91 | new_sentences.append(tmp_s) 92 | return new_sentences 93 | 94 | def __call__(self, batch): 95 | raw_sentences = [] 96 | raw_sentences_flatten = [] 97 | labels = [] 98 | 99 | # masked tensor 100 | lengths = [len(sample['sentences']) for sample in batch] 101 | max_len_conversation = max(lengths) 102 | padding_utterance_masked = torch.BoolTensor([[False]*l_i+ [True]*(max_len_conversation - l_i) for l_i in lengths]) 103 | 104 | # collect all sentences 105 | # - intra speaker 106 | intra_speaker_masekd_all = torch.BoolTensor(len(batch), max_len_conversation,max_len_conversation) 107 | for i, sample in enumerate(batch): 108 | sentences_mixed_arround = self.sentence_mixed_by_surrounding(sample['sentences'], 109 | around_window=self.window_ct, 110 | s_id=sample['s_id'], 111 | genders=sample['genders'], 112 | data_name=self.dataset_name) 113 | 114 | # conversation padding 115 | padded_conversation = sentences_mixed_arround + [""]* (max_len_conversation - lengths[i]) 116 | raw_sentences.append(padded_conversation) 117 | raw_sentences_flatten += padded_conversation 118 | 119 | # label padding 120 | labels += [int(label) for label in sample['labels']] + [-1]* (max_len_conversation - lengths[i]) 121 | 122 | # speaker 123 | intra_speaker_masekd= torch.BoolTensor(len(padded_conversation),len(padded_conversation)).fill_(False) 124 | for j in range(len( sample['genders'])): 125 | for k in range(len( sample['genders'])): 126 | gender_j = sample['genders'][j] 127 | gender_k = sample['genders'][k] 128 | 129 | if gender_j == gender_k: 130 | intra_speaker_masekd[j][k] = True 131 | else: 132 | 
intra_speaker_masekd[j][k] = False 133 | 134 | intra_speaker_masekd_all[i] = intra_speaker_masekd 135 | 136 | if len(labels)!= len(raw_sentences_flatten): 137 | print('len(labels)!= len(raw_sentences_flatten)') 138 | 139 | # utterance vectorizer 140 | # v_single_sentences = self._encoding(sample['sentences']) 141 | contextual_sentences_ids = self.tokenizer(raw_sentences_flatten, padding='longest', max_length=512, truncation=True, return_tensors='pt') 142 | sent_indices, word_indices = torch.where(contextual_sentences_ids['input_ids'] == self.separate_token_id) 143 | gr_sent_indices = [[] for e in range(len(raw_sentences_flatten))] 144 | for sent_idx, w_idx in zip (sent_indices, word_indices): 145 | gr_sent_indices[sent_idx].append(w_idx.item()) 146 | 147 | cur_sentence_indexes_masked = torch.BoolTensor(contextual_sentences_ids['input_ids'].shape).fill_(False) 148 | for i in range(contextual_sentences_ids['input_ids'].shape[0]): 149 | if raw_sentences_flatten[i] =='': 150 | cur_sentence_indexes_masked[i][gr_sent_indices[i][0]] = True 151 | continue 152 | for j in range(contextual_sentences_ids['input_ids'].shape[1]): 153 | if gr_sent_indices[i][0] <= j <= gr_sent_indices[i][1]: 154 | cur_sentence_indexes_masked[i][j] = True 155 | 156 | return (contextual_sentences_ids, torch.LongTensor(labels), padding_utterance_masked, intra_speaker_masekd_all, cur_sentence_indexes_masked, raw_sentences) 157 | 158 | 159 | class BatchPreprocessorLLM(BatchPreprocessor): 160 | def __init__(self, tokenizer, dataset_name=None, window_ct=2, emotion_labels=[]) -> None: 161 | self.tokenizer = tokenizer 162 | self.separate_token_id = self.tokenizer.convert_tokens_to_ids("") 163 | self.dataset_name = dataset_name 164 | self.window_ct = window_ct 165 | self.emotion_labels = emotion_labels 166 | self.printted = False 167 | 168 | @staticmethod 169 | def load_raw_data(path_data): 170 | raw_data = json.load(open(path_data)) 171 | if isinstance(raw_data, dict): 172 | new_data_list = [] 173 | for k, v in raw_data.items(): 174 | v['s_id'] = k 175 | new_data_list.append(v) 176 | return new_data_list 177 | elif isinstance(raw_data, list): 178 | return raw_data 179 | 180 | @staticmethod 181 | def get_speaker_name(s_id, gender, data_name): 182 | if data_name == "iemocap": 183 | # iemocap: label index mapping = {'hap':0, 'sad':1, 'neu':2, 'ang':3, 'exc':4, 'fru':5} 184 | speaker = { 185 | "Ses01": {"F": "Mary", "M": "James"}, 186 | "Ses02": {"F": "Patricia", "M": "John"}, 187 | "Ses03": {"F": "Jennifer", "M": "Robert"}, 188 | "Ses04": {"F": "Linda", "M": "Michael"}, 189 | "Ses05": {"F": "Elizabeth", "M": "William"}, 190 | } 191 | s_id_first_part = s_id[:5] 192 | return speaker[s_id_first_part][gender].upper() 193 | elif data_name in ['meld', "emorynlp"]: 194 | # emorynlp: label index mapping = {'Joyful': 0, 'Mad': 1, 'Peaceful': 2, 'Neutral': 3, 'Sad': 4, 'Powerful': 5, 'Scared': 6} 195 | # meld: label index mapping = {'neutral': 0, 'surprise': 1, 'fear': 2, 'sadness': 3, 'joy': 4, 'disgust': 5, 'anger':6} 196 | gender_idx = gender.index(1) 197 | return f"SPEAKER_{gender_idx}" 198 | elif data_name == 'dailydialog': 199 | # dailydialog: {'no_emotion': 0, 'happiness': 1, 'sadness': 2, 'surprise': 3, 'anger': 4, 'fear': 5, 'disgust':6} 200 | return f"SPEAKER_{gender}" 201 | 202 | def sentence_mixed_by_surrounding(self, sentences, around_window, s_id, genders, data_name): 203 | new_conversations = [] 204 | align_sents = [] 205 | for i, cur_sent in enumerate(sentences): 206 | tmp_s = "" 207 | for j in range(max(0, i-around_window), 
min(len(sentences), i+around_window+1)): 208 | u_j = f"{self.get_speaker_name(s_id, genders[j], data_name=data_name)}: {sentences[j]}" 209 | if i == j: 210 | align_sents.append(u_j) 211 | tmp_s += f"\n{u_j}" 212 | new_conversations.append(tmp_s) 213 | return new_conversations, align_sents 214 | 215 | def __call__(self, batch): 216 | raw_sentences = [] 217 | raw_sentences_flatten = [] 218 | labels = [] 219 | speaker_info = [] 220 | listener_info = [] 221 | 222 | # masked tensor 223 | lengths = [len(sample['sentences']) for sample in batch] 224 | max_len_conversation = max(lengths) 225 | padding_utterance_masked = torch.BoolTensor( 226 | [[False]*l_i + [True]*(max_len_conversation - l_i) for l_i in lengths]) 227 | 228 | # collect all sentences 229 | # - intra speaker 230 | flatten_data = [] 231 | intra_speaker_masekd_all = torch.BoolTensor( 232 | len(batch), max_len_conversation, max_len_conversation) 233 | for i, sample in enumerate(batch): 234 | new_conversations, align_sents = self.sentence_mixed_by_surrounding(sample['sentences'], 235 | around_window=self.window_ct, 236 | s_id=sample['s_id'], 237 | genders=sample['genders'], 238 | data_name=self.dataset_name) 239 | few_shot_example = """\n======= 240 | Context: Given predefined emotional label set [happy, sad, neutral, angry, excited, frustrated], and bellow conversation: 241 | " 242 | PATRICIA: You know, it's lovely here, the air is sweet. 243 | PATRICIA: No, not sorry. But, um. But I'm not gonna stay. 244 | JOHN: The trouble is, I planned on sort of sneaking up on you on a period of a week or so. But they take it for granted that we're all set. 245 | PATRICIA: I knew they would, your mother anyway. 246 | PATRICIA: Well, from her point of view, why else would I come? 247 | PATRICIA: I guess this is why I came. 248 | JOHN: I'm embarrassing you and I didn't want to tell it to you here. I wanted some place we'd never been before. A place where we'd be brand new to each other. 249 | PATRICIA: Well, you started to write me 250 | JOHN: You felt something that far back? 251 | PATRICIA: Every day since. 252 | JOHN: Ann, why didn't you let me know? 253 | JOHN: Let's drive someplace. I want to be alone with you. 254 | JOHN: No. Nothing like that. 255 | " 256 | 257 | Question: What is the emotion of the speaker at the utterance "PATRICIA: Well, from her point of view, why else would I come?"? 258 | Answer: neutral 259 | 260 | Question: What is the emotion of the speaker at the utterance "PATRICIA: I guess this is why I came."? 261 | Answer: happy 262 | 263 | Question: What is the emotion of the speaker at the utterance "JOHN: I'm embarrassing you and I didn't want to tell it to you here. I wanted some place we'd never been before. A place where we'd be brand new to each other."? 
264 | Answer: excited 265 | """ 266 | for i_u, (conv, utterance) in enumerate(zip(new_conversations, align_sents)): 267 | prompt_extract_context_vect = few_shot_example + \ 268 | f"\n=======\nContext: Given predefined emotional label set [{', '.join(self.emotion_labels)}], and bellow conversation:\n\"{conv}\n\"\n\nQuestion: What is the emotion of the speaker at the utterance \"{utterance}\"?\nAnswer:" 269 | if not self.printted: 270 | print(prompt_extract_context_vect) 271 | self.printted = True 272 | 273 | inputs = self.tokenizer( 274 | prompt_extract_context_vect, return_tensors="pt") 275 | input_ids = inputs["input_ids"] 276 | flatten_data.append({ 277 | "s_id": sample['s_id'], 278 | "u_idx": i_u, 279 | "prompt_content": prompt_extract_context_vect, 280 | "input_ids": input_ids, 281 | } 282 | ) 283 | 284 | return flatten_data 285 | 286 | 287 | class BatchPreprocessorLLMSpeakerDescription(BatchPreprocessor): 288 | def __init__(self, tokenizer, dataset_name=None, window_ct=2, emotion_labels=[]) -> None: 289 | self.tokenizer = tokenizer 290 | self.separate_token_id = self.tokenizer.convert_tokens_to_ids("") 291 | self.dataset_name = dataset_name 292 | self.window_ct = window_ct 293 | self.emotion_labels = emotion_labels 294 | 295 | @staticmethod 296 | def load_raw_data(path_data): 297 | raw_data = json.load(open(path_data)) 298 | if isinstance(raw_data, dict): 299 | new_data_list = [] 300 | for k, v in raw_data.items(): 301 | v['s_id'] = k 302 | new_data_list.append(v) 303 | return new_data_list 304 | elif isinstance(raw_data, list): 305 | return raw_data 306 | 307 | @staticmethod 308 | def get_speaker_name(s_id, gender, data_name): 309 | if data_name == "iemocap": 310 | # iemocap: label index mapping = {'hap':0, 'sad':1, 'neu':2, 'ang':3, 'exc':4, 'fru':5} 311 | speaker = { 312 | "Ses01": {"F": "Mary", "M": "James"}, 313 | "Ses02": {"F": "Patricia", "M": "John"}, 314 | "Ses03": {"F": "Jennifer", "M": "Robert"}, 315 | "Ses04": {"F": "Linda", "M": "Michael"}, 316 | "Ses05": {"F": "Elizabeth", "M": "William"}, 317 | } 318 | s_id_first_part = s_id[:5] 319 | return speaker[s_id_first_part][gender].upper() 320 | elif data_name in ['meld', "emorynlp"]: 321 | # emorynlp: label index mapping = {'Joyful': 0, 'Mad': 1, 'Peaceful': 2, 'Neutral': 3, 'Sad': 4, 'Powerful': 5, 'Scared': 6} 322 | # meld: label index mapping = {'neutral': 0, 'surprise': 1, 'fear': 2, 'sadness': 3, 'joy': 4, 'disgust': 5, 'anger':6} 323 | gender_idx = gender.index(1) 324 | return f"SPEAKER_{gender_idx}" 325 | elif data_name == 'dailydialog': 326 | # dailydialog: {'no_emotion': 0, 'happiness': 1, 'sadness': 2, 'surprise': 3, 'anger': 4, 'fear': 5, 'disgust':6} 327 | return f"SPEAKER_{gender}" 328 | 329 | def preprocess(self, all_conversations): 330 | 331 | new_data = {} 332 | gr_by_len = {} 333 | for i, sample in enumerate(all_conversations): 334 | 335 | all_utterances = [] 336 | all_speaker_names = [] 337 | for i_u, u in enumerate(sample['sentences']): 338 | speaker_name = self.get_speaker_name( 339 | sample['s_id'], sample['genders'][i_u], self.dataset_name) 340 | u_full_name = f'{speaker_name}: {u}' 341 | all_utterances.append(u_full_name) 342 | all_speaker_names.append(speaker_name) 343 | 344 | full_conversation = "\n".join(all_utterances) 345 | prompts_speaker_description_word_ids = {} 346 | prompting_input = {} 347 | for speaker_name in set(all_speaker_names): 348 | prompting = "\nGiven this conversation between speakers: \n\"\n" + full_conversation + \ 349 | "\n\"\nIn overall of above conversation, what do you think 
about the characteristics speaker {}? (Note: provide an answer within 250 words)".format(speaker_name) 350 | 351 | prompts_speaker_description_word_ids[speaker_name] = self.tokenizer( 352 | prompting, return_tensors="pt")["input_ids"] 353 | prompting_input[speaker_name] = prompting 354 | 355 | # group by len for batch decode by llm 356 | if prompts_speaker_description_word_ids[speaker_name].shape[-1] not in gr_by_len: 357 | gr_by_len[prompts_speaker_description_word_ids[speaker_name].shape[-1]] = [] 358 | gr_by_len[prompts_speaker_description_word_ids[speaker_name].shape[-1]].append({ 359 | 'w_ids': prompts_speaker_description_word_ids[speaker_name], 360 | 'conv_id': sample['s_id'], 361 | 'type_data': sample['type_data'], 362 | "prompting_input": prompting, 363 | 'speaker_name': speaker_name, 364 | 'all_speaker_names': all_speaker_names 365 | }) 366 | 367 | return gr_by_len 368 | 369 | 370 | raw_data = [] 371 | for type_data in ['valid', 'test', 'train']: 372 | data_name_pattern = f'{dataset_name}.{type_data}' 373 | path_processed_data = f'{data_folder}/{data_name_pattern}_{prompt_type}_{model_name.split("/")[-1]}.json' 374 | 375 | org_raw_data = BatchPreprocessorLLMSpeakerDescription.load_raw_data( 376 | f"{data_folder}/{data_name_pattern}.json") 377 | 378 | if os.path.exists(path_processed_data): 379 | processed_data = json.load(open(path_processed_data, 'rt')) 380 | print( 381 | f'- sucessful processed {len(processed_data)}/{len(org_raw_data)} conversations in data-type ={type_data}') 382 | json.dump(processed_data, open( 383 | path_processed_data+"_backup.json", 'wt'), indent=2) 384 | org_raw_data = [e for e in org_raw_data if e['s_id'] 385 | not in processed_data] 386 | 387 | print( 388 | f'- Continue process {len(org_raw_data)} conversations in data-type ={type_data}') 389 | for e in org_raw_data: 390 | e['type_data'] = type_data 391 | raw_data = raw_data + org_raw_data 392 | 393 | data_preprocessor = BatchPreprocessorLLMSpeakerDescription(tokenizer, dataset_name=dataset_name, window_ct=4, 394 | emotion_labels=['happy', 'sad', 'neutral', 'angry', 'excited', 'frustrated']) 395 | 396 | gr_by_len = data_preprocessor.preprocess(raw_data) 397 | all_data = {} 398 | print_one_time = True 399 | for len_promting, speaker_promts in tqdm(gr_by_len.items()): 400 | for batch_size in [8, 5, 2, 1]: 401 | try: 402 | all_promtings_texts = [e['prompting_input'] 403 | for e in speaker_promts] 404 | data_loader = DataLoader(all_promtings_texts, 405 | batch_size=batch_size, 406 | shuffle=False) 407 | output_sp_desc = [] 408 | with torch.no_grad(): 409 | for i, speaker_promts_in_batch in enumerate(data_loader): 410 | # batch decoded by llm 411 | inputs = tokenizer(speaker_promts_in_batch, 412 | return_tensors="pt", padding=False) 413 | input_ids = inputs["input_ids"].to("cuda") 414 | with torch.no_grad(): 415 | outputs = model.generate(input_ids, max_new_tokens=300) 416 | output_text = tokenizer.batch_decode( 417 | outputs, skip_special_tokens=True) 418 | 419 | for j, e in enumerate(output_text): 420 | output_sp_desc.append( 421 | e.replace(all_promtings_texts[j], "")) 422 | 423 | if print_one_time: 424 | print(output_text) 425 | print(output_sp_desc) 426 | print_one_time = False 427 | 428 | for i, out in enumerate(output_sp_desc): 429 | speaker_promts[i]['sp_desc'] = out 430 | break 431 | 432 | except Exception as e: 433 | traceback.print_exc() 434 | print(e) 435 | if batch_size == 1: 436 | print(["Errr "]*10) 437 | 438 | for type_data in ['valid', 'test', 'train']: 439 | data_name_pattern = 
f'{dataset_name}.{type_data}' 440 | path_processed_data = f'{data_folder}/{data_name_pattern}_{prompt_type}_{model_name.split("/")[-1]}.json' 441 | 442 | processed_data = {} 443 | if os.path.exists(path_processed_data): 444 | processed_data = json.load(open(path_processed_data, 'rt')) 445 | print( 446 | f'- loaded [old] {len(processed_data)} previously processed conversations in data-type ={type_data}') 447 | 448 | all_data = {}  # conv_id -> {'all_speaker_names': [...], 'vocab_sp2desc': {speaker_name: description}} 449 | for len_promting, speaker_promts in gr_by_len.items(): 450 | for description in speaker_promts: 451 | if type_data != description['type_data']: 452 | continue 453 | 454 | if description['conv_id'] not in all_data: 455 | all_data[description['conv_id']] = { 456 | 'all_speaker_names': description['all_speaker_names'], 457 | 'vocab_sp2desc': {} 458 | } 459 | all_data[description['conv_id'] 460 | ]['vocab_sp2desc'][description['speaker_name']] = description['sp_desc'] 461 | 462 | print( 463 | f'- successfully processed [new] {len(all_data)} conversations in data-type ={type_data}') 464 | # json.dump(all_data, open(f'{path_data}_new.json', 'wt'), indent=2) 465 | 466 | all_data_new = {}  # conv_id -> one speaker-description string per utterance, in utterance order 467 | for k, v in all_data.items(): 468 | all_data_new[k] = [] 469 | for sp_name in v['all_speaker_names']: 470 | all_data_new[k].append(v['vocab_sp2desc'][sp_name]) 471 | 472 | print( 473 | f'- updated: [new] {len(all_data_new)} + [old] {len(processed_data)} conversations in data-type ={type_data}') 474 | all_data_new.update(processed_data)  # keep previously generated descriptions when merging 475 | json.dump(all_data_new, open(f'{path_processed_data}', 'wt'), indent=2) 476 | --------------------------------------------------------------------------------
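Note: src/llm_bio_extract.py writes one JSON file per split, following the pattern {data_folder}/{dataset_name}.{type_data}_{prompt_type}_{model_name}.json, mapping each conversation id to a list of LLM-generated speaker descriptions, one entry per utterance and aligned with utterance order. The lines below are a minimal sketch (not part of the repository) for loading and inspecting such a file; the path is only an example and assumes the script's default settings.

import json

# hypothetical example path following the naming pattern used in src/llm_bio_extract.py
path = "./data/iemocap.valid_spdescV2_Llama-2-70b-chat-hf.json"

conv2desc = json.load(open(path, "rt"))  # {conversation_id: [speaker description per utterance]}
conv_id, descriptions = next(iter(conv2desc.items()))
print(conv_id, len(descriptions))        # conversation id and its number of utterances
print(descriptions[0][:300])             # description of the speaker of the first utterance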