├── .gitignore ├── images └── fig1.png ├── models ├── model-tune.yaml ├── __init__.py ├── modules │ └── encoder.py ├── auto_attention.py └── emo_net.py ├── preprocess ├── make_lst.sh ├── check_feat_dim.py ├── check_single_feat_dim.py ├── mer2025_base.py └── mer2024_base.py ├── toolkit └── utils │ ├── read_files.py │ ├── eval.py │ ├── draw_process.py │ ├── functions.py │ ├── metric.py │ └── loss.py ├── feature_extraction ├── run.sh ├── config.py ├── audio │ ├── split_audio.py │ └── extract_audio_huggingface.py ├── visual │ ├── split_video.py │ ├── util.py │ ├── extract_openface.py │ └── extract_vision_huggingface.py └── text │ ├── split_asr.py │ └── extract_text_huggingface.py ├── run_eval_unimodal.sh ├── eval_avg_performance.py ├── environment.yml ├── prompt.txt ├── dataloader └── dataloader.py ├── README.md ├── config.py ├── test.py ├── LICENSE └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | saved/ 2 | __pycache__/ 3 | .DS_Store -------------------------------------------------------------------------------- /images/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuyjan/MER2025-MRAC25/HEAD/images/fig1.png -------------------------------------------------------------------------------- /models/model-tune.yaml: -------------------------------------------------------------------------------- 1 | auto_attention: 2 | train_input_mode: input 3 | feat_type: utt 4 | hidden_dim: 256 5 | dropout: 0.2 6 | grad_clip: -1.0 7 | -------------------------------------------------------------------------------- /preprocess/make_lst.sh: -------------------------------------------------------------------------------- 1 | python mer2025_base.py --seed 0 2 | python mer2025_base.py --seed 1 3 | python mer2025_base.py --seed 2 4 | python mer2025_base.py --seed 3 5 | python mer2025_base.py --seed 4 -------------------------------------------------------------------------------- /preprocess/check_feat_dim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from glob import glob 4 | 5 | import numpy as np 6 | 7 | if __name__ == "__main__": 8 | feat_root = "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/" 9 | flst = glob(osp.join(feat_root, "*")) 10 | 11 | for feat_dir in flst: 12 | feat_name = osp.basename(feat_dir) 13 | fpath = glob(osp.join(feat_dir, "*"))[0] 14 | feat_dim = np.load(fpath).shape[0] 15 | 16 | print("{}, dim={}".format(feat_name, feat_dim)) 17 | -------------------------------------------------------------------------------- /toolkit/utils/read_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import random 5 | import numpy as np 6 | import pandas as pd 7 | 8 | 9 | # 功能3:从csv中读取特定的key对应的值 10 | def func_read_key_from_csv(csv_path, key): 11 | values = [] 12 | df = pd.read_csv(csv_path) 13 | for _, row in df.iterrows(): 14 | if key not in row: 15 | values.append("") 16 | else: 17 | value = row[key] 18 | if pd.isna(value): 19 | value = "" 20 | values.append(value) 21 | return values 22 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .auto_attention import Auto_Attention 4 | 5 | 6 | class get_models(nn.Module): 
7 | def __init__(self, args): 8 | super(get_models, self).__init__() 9 | # misa/mmim在有些参数配置下会存在梯度爆炸的风险 10 | # tfn 显存占比比较高 11 | 12 | MODEL_MAP = { 13 | # 特征压缩到句子级再处理,所以支持 utt/align/unalign 14 | "auto_attention": Auto_Attention, 15 | } 16 | self.model = MODEL_MAP[args.model](args) 17 | 18 | def forward(self, batch): 19 | return self.model(batch) 20 | -------------------------------------------------------------------------------- /toolkit/utils/eval.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from sklearn.metrics import mean_squared_error 4 | from sklearn.metrics import f1_score, accuracy_score 5 | 6 | 7 | def calculate_results(emo_probs=[], emo_labels=[], val_preds=[], val_labels=[]): 8 | 9 | emo_preds = np.argmax(emo_probs, 1) 10 | emo_accuracy = accuracy_score(emo_labels, emo_preds) 11 | emo_fscore = f1_score(emo_labels, emo_preds, average="weighted") 12 | 13 | results = { 14 | "emoprobs": emo_probs, 15 | "emolabels": emo_labels, 16 | "emoacc": emo_accuracy, 17 | "emofscore": emo_fscore, 18 | } 19 | outputs = f"f1:{emo_fscore:.4f}_acc:{emo_accuracy:.4f}" 20 | return results, outputs 21 | -------------------------------------------------------------------------------- /preprocess/check_single_feat_dim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from glob import glob 4 | 5 | import numpy as np 6 | 7 | if __name__ == "__main__": 8 | feat_root = "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/InternVL_2_5_HiCo_R16-UTT" 9 | npy_files = glob(osp.join(feat_root, "*.npy")) 10 | 11 | print(f"Found {len(npy_files)} npy files in {feat_root}") 12 | print("-" * 50) 13 | 14 | for npy_file in npy_files: 15 | file_name = osp.basename(npy_file) 16 | try: 17 | data = np.load(npy_file) 18 | print(f"{file_name}: shape={data.shape}") 19 | except Exception as e: 20 | print(f"{file_name}: Error loading file - {e}") 21 | -------------------------------------------------------------------------------- /toolkit/utils/draw_process.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | import os.path as osp 4 | 5 | 6 | def draw_loss(epoch, train_loss, valid_loss, save_path): 7 | plt.figure() 8 | plt.plot(epoch, train_loss, label="train loss") 9 | plt.plot(epoch, valid_loss, label="valid loss") 10 | plt.legend() 11 | plt.savefig( 12 | save_path, 13 | bbox_inches="tight", 14 | dpi=500, 15 | ) 16 | plt.close() 17 | return 18 | 19 | 20 | def draw_metric(epoch, metric, key, save_path): 21 | plt.figure() 22 | plt.plot(epoch, metric, label="valid {}".format(key)) 23 | plt.legend() 24 | plt.savefig( 25 | save_path, 26 | bbox_inches="tight", 27 | dpi=500, 28 | ) 29 | plt.close() 30 | -------------------------------------------------------------------------------- /toolkit/utils/functions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | # config -> args [只把config里存在,但是args中不存在或者为None的部分赋值] 5 | def merge_args_config(args, config): 6 | args_dic = vars(args) # convert to map version 7 | for key in config: 8 | if key not in args_dic or args_dic[key] is None: 9 | args_dic[key] = config[key] 10 | args_new = argparse.Namespace(**args_dic) # change to namespace 11 | return args_new 12 | 13 | 14 | def func_update_storage(inputs, prefix, outputs): 15 | for key in inputs: 16 | val = inputs[key] 17 | # 
update key and value 18 | newkey = f"{prefix}_{key}" 19 | newval = val 20 | # store into outputs 21 | assert newkey not in outputs 22 | outputs[newkey] = newval 23 | -------------------------------------------------------------------------------- /toolkit/utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.metrics import mean_squared_error 4 | from sklearn.metrics import f1_score, accuracy_score 5 | 6 | 7 | # 综合维度和离散的评价指标 8 | def overall_metric(emo_fscore, val_mse): 9 | final_score = emo_fscore - val_mse * 0.25 10 | return final_score 11 | 12 | 13 | # 只返回 metric 值,用于模型筛选 14 | def gain_metric_from_results(eval_results, metric_name="emoval"): 15 | 16 | if metric_name == "emoval": 17 | fscore = eval_results["emofscore"] 18 | valmse = eval_results["valmse"] 19 | overall = overall_metric(fscore, valmse) 20 | sort_metric = overall 21 | elif metric_name == "emo": 22 | fscore = eval_results["emofscore"] 23 | sort_metric = fscore 24 | elif metric_name == "val": 25 | valmse = eval_results["valmse"] 26 | sort_metric = -valmse 27 | elif metric_name == "loss": 28 | loss = eval_results["loss"] 29 | sort_metric = -loss 30 | 31 | return sort_metric 32 | -------------------------------------------------------------------------------- /feature_extraction/run.sh: -------------------------------------------------------------------------------- 1 | 2 | project_dir="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/" 3 | 4 | # --------------- 视觉预处理 5 | # --- 提取人脸 6 | # cd ./visual 7 | # python extract_openface.py --dataset=$project_dir --type="videoOne" 8 | # cd ../ 9 | # # 10 | # # --- 提取视觉特征 11 | # cd ./visual 12 | # python -u extract_vision_huggingface.py --dataset=$project_dir --feature_level='UTTERANCE' --model_name='clip-vit-large-patch14' --gpu=0 13 | # cd ../ 14 | 15 | # # --------------- 音频预处理 16 | # # --- 分离音频 17 | # cd ./audio 18 | # python split_audio.py --dataset=$project_dir 19 | # cd ../ 20 | # # 21 | # # --- 提取音频特征 22 | # cd ./audio 23 | # python -u extract_audio_huggingface.py --dataset=$project_dir --feature_level='UTTERANCE' --model_name='chinese-hubert-large' --gpu=0 24 | # cd ../ 25 | 26 | # # --------------- 文本预处理 27 | # #--- 获取文本 28 | # cd ./text 29 | # python split_asr.py --dataset=$project_dir 30 | # cd ../ 31 | # # 32 | # # --- 提取文本特征 33 | cd ./text 34 | python extract_text_huggingface.py --dataset=$project_dir --feature_level='UTTERANCE' --model_name='bloom-7b1' --gpu=0 35 | cd ../ 36 | 37 | -------------------------------------------------------------------------------- /run_eval_unimodal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | # 定义feat和seed的取值范围 4 | feats=("InternVL_2_5_HiCo_R16-UTT") 5 | seeds=(0 1 2 3 4) 6 | 7 | # 外层循环遍历feat 8 | for feat in "${feats[@]}"; do 9 | # 内层循环遍历seed 10 | for seed in "${seeds[@]}"; do 11 | echo "Running with seed=$seed, feat=$feat" 12 | 13 | # 执行命令 14 | python train.py --seed $seed \ 15 | --dataset "seed${seed}" \ 16 | --emo_rule "MER" \ 17 | --save_model \ 18 | --save_root "./saved/${feat}_3" \ 19 | --feat "['${feat}', '${feat}', '${feat}']" \ 20 | --lr 0.0001 \ 21 | --batch_size 512 \ 22 | --num_workers 4 \ 23 | --epochs 80 \ 24 | 25 | # 检查命令是否执行成功 26 | if [ $? 
-eq 0 ]; then 27 | echo "Command executed successfully for seed=$seed, feat=$feat" 28 | else 29 | echo "Error executing command for seed=$seed, feat=$feat" 30 | exit 1 31 | fi 32 | done 33 | done -------------------------------------------------------------------------------- /feature_extraction/config.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import sys 4 | import socket 5 | import os.path as osp 6 | 7 | ############ global path ############## 8 | # PATH_TO_PROJECT = ( 9 | # "/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/" 10 | # ) 11 | # PATH_TO_VIDEO = osp.join(PATH_TO_PROJECT, "video") # 文件的初始文件夹 12 | 13 | # PATH_TO_FEATURES = osp.join(PATH_TO_PROJECT, "features") # 剪辑片段的特征的保存文件夹 14 | 15 | 16 | # PATH_TO_RAW_FACE_Win = PATH_TO_VIDEO 17 | # PATH_TO_FEATURES_Win = PATH_TO_FEATURES 18 | 19 | ############ Models ############## 20 | 21 | # pre-trained models, including supervised and unsupervised 22 | PATH_TO_PRETRAINED_MODELS = "/sda/xyy/mer/MERTools/MER2024/tools" 23 | PATH_TO_OPENSMILE = "/sda/xyy/mer/MERTools/MER2024/tools/opensmile-2.3.0/" 24 | PATH_TO_FFMPEG = "/sda/xyy/mer/MERTools/MER2024/tools/ffmpeg-4.4.1-i686-static/ffmpeg" 25 | PATH_TO_WENET = ( 26 | "/sda/xyy/mer/MERTools/MER2024/tools/wenet/wenetspeech_u2pp_conformer_libtorch" 27 | ) 28 | 29 | 30 | PATH_TO_OPENFACE_Win = "/sda/xyy/mer/tools/OpenFace_2.2.0" 31 | 32 | PATH_TO_FFMPEG_Win = ( 33 | "/sda/xyy/mer/MERTools/MER2024/tools/ffmpeg-3.4.1-win32-static/bin/ffmpeg" 34 | ) 35 | -------------------------------------------------------------------------------- /eval_avg_performance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | if __name__ == "__main__": 7 | seeds = [0, 1, 2, 3, 4] 8 | feats = [ 9 | "InternVL_2_5_HiCo_R16-UTT", 10 | ] 11 | # feats = [ 12 | # "chinese-hubert-large-UTT", 13 | # "Qwen2-Audio-7B-UTT", 14 | # "chinese-hubert-base-UTT", 15 | # "whisper-large-v2-UTT", 16 | # "chinese-wav2vec2-large-UTT", 17 | # "chinese-wav2vec2-base-UTT", 18 | # "wavlm-base-UTT", 19 | # ] 20 | # feats = [ 21 | # "chinese-roberta-wwm-ext-large-UTT", 22 | # "chinese-roberta-wwm-ext-UTT", 23 | # "chinese-macbert-large-UTT", 24 | # "chinese-macbert-base-UTT", 25 | # "bloom-7b1-UTT", 26 | # ] 27 | 28 | for feat in feats: 29 | avg_acc = 0 30 | avg_fscore = 0 31 | cnt = 0 32 | for seed in seeds: 33 | fdir = osp.join( 34 | "./saved", feat + "_3", "seed" + str(seed), "auto_attention", "run" 35 | ) 36 | 37 | res = np.load( 38 | osp.join(fdir, "best_valid_results.npy"), allow_pickle=True 39 | ).item() 40 | 41 | avg_acc += res["eval_emoacc"] 42 | avg_fscore += res["eval_emofscore"] 43 | cnt += 1 44 | 45 | avg_acc /= cnt 46 | avg_fscore /= cnt 47 | print("### {}".format(feat + "_3")) 48 | print("acc={:.4f}, fscore={:.4f}".format(avg_acc, avg_fscore)) 49 | print("") 50 | 51 | a = 1 52 | -------------------------------------------------------------------------------- /feature_extraction/audio/split_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import glob 4 | import sys 5 | 6 | current_file_path = os.path.abspath(__file__) 7 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path))) 8 | import config 9 | import argparse 10 | 11 | 12 | def func_split_audio_from_video_16k(video_root, save_root): 13 | if not os.path.exists(save_root): 14 | 
os.makedirs(save_root) 15 | for video_path in tqdm.tqdm(glob.glob(video_root + "/*")): 16 | videoname = os.path.basename(video_path)[:-4] 17 | audio_path = os.path.join(save_root, videoname + ".wav") 18 | if os.path.exists(audio_path): 19 | continue 20 | cmd = "%s -loglevel quiet -y -i %s -ar 16000 -ac 1 %s" % ( 21 | config.PATH_TO_FFMPEG, 22 | video_path, 23 | audio_path, 24 | ) # linux 25 | # cmd = "%s -loglevel quiet -y -i %s -ar 16000 -ac 1 %s" %(config.PATH_TO_FFMPEG_Win, video_path, audio_path) # windows 26 | os.system(cmd) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | "--dataset", 33 | type=str, 34 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/", 35 | help="file name", 36 | ) 37 | args = parser.parse_args() 38 | 39 | dataset = args.dataset 40 | 41 | video_root = os.path.join(dataset, "video") 42 | save_root = os.path.join(dataset, "audio") 43 | func_split_audio_from_video_16k(video_root, save_root) 44 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mertools 2 | channels: 3 | - pytorch 4 | - defaults 5 | - anaconda 6 | dependencies: 7 | - python=3.9 8 | - cudatoolkit 9 | - pip 10 | - pytorch=1.12.1 11 | - pytorch-mutex=1.0=cuda 12 | - torchaudio=0.12.1 13 | - torchvision=0.13.1 14 | 15 | - pip: 16 | - accelerate==0.16.0 17 | - aiohttp==3.8.4 18 | - aiosignal==1.3.1 19 | - async-timeout==4.0.2 20 | - attrs==22.2.0 21 | - bitsandbytes==0.37.0 22 | - cchardet==2.1.7 23 | - chardet==5.1.0 24 | - contourpy==1.0.7 25 | - cycler==0.11.0 26 | - filelock==3.9.0 27 | - fonttools==4.38.0 28 | - frozenlist==1.3.3 29 | - huggingface-hub==0.13.4 30 | - importlib-resources==5.12.0 31 | - kiwisolver==1.4.4 32 | - matplotlib==3.7.0 33 | - multidict==6.0.4 34 | - openai==0.27.0 35 | - packaging==23.0 36 | - psutil==5.9.4 37 | - pycocotools==2.0.6 38 | - pyparsing==3.0.9 39 | - python-dateutil==2.8.2 40 | - pyyaml==6.0 41 | - regex==2022.10.31 42 | - tokenizers==0.13.2 43 | - tqdm==4.64.1 44 | - transformers==4.28.0 45 | - timm==0.6.13 46 | - spacy==3.5.1 47 | - webdataset==0.2.48 48 | - scikit-learn==1.2.2 49 | - scipy==1.10.1 50 | - yarl==1.8.2 51 | - zipp==3.14.0 52 | - omegaconf==2.3.0 53 | - opencv-python==4.7.0.72 54 | - iopath==0.1.10 55 | - decord==0.6.0 56 | - tenacity==8.2.2 57 | - peft 58 | - pycocoevalcap 59 | - sentence-transformers 60 | - umap-learn 61 | - notebook 62 | - gradio==3.24.1 63 | - gradio-client==0.0.8 64 | - wandb 65 | - einops 66 | - SentencePiece 67 | - ftfy 68 | - thop 69 | - pytorchvideo==0.1.5 -------------------------------------------------------------------------------- /feature_extraction/visual/split_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import cv2 4 | from moviepy.editor import VideoFileClip 5 | import argparse 6 | import sys 7 | 8 | current_file_path = os.path.abspath(__file__) 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))) 10 | import config 11 | 12 | if __name__ == "__main__": 13 | split_t = config.VIDEO_SPLIT_T # 分割的间断时长 14 | duration_t = config.VIDEO_DURATION_T # 每段分割的持续时长 15 | 16 | # ----------- 需要指定的路径 17 | video_dir = config.PATH_TO_VIDEO # 文件的初始文件夹 18 | save_basedir = config.PATH_TO_VIDEO_CLIP # 剪辑片段的保存文件夹 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 
22 | "--dataset", 23 | type=str, 24 | default="test", 25 | help="file name", 26 | ) 27 | args = parser.parse_args() 28 | 29 | ftitle = args.dataset 30 | 31 | # 视频文件的路径 32 | video_path = osp.join(video_dir, ftitle + ".mp4") 33 | 34 | # 保存路径 35 | output_folder = osp.join(save_basedir, ftitle) 36 | if not osp.exists(output_folder): 37 | os.makedirs(output_folder) 38 | 39 | # 加载视频 40 | clip = VideoFileClip(video_path) 41 | 42 | # 视频总时长 43 | duration = clip.duration 44 | 45 | # 分割视频的起始时间 46 | start_time = 0 47 | 48 | # 循环分割视频,直到视频结束 49 | cnt = 0 50 | for i in range(0, int(duration // split_t)): 51 | # 计算结束时间,确保不超过视频总时长 52 | end_time = min(start_time + duration_t, duration) 53 | 54 | # 创建子剪辑 55 | subclip = clip.subclip(start_time, end_time) 56 | 57 | # 生成输出文件名 58 | output_filename = os.path.join(output_folder, f"clip_{cnt}.mp4") 59 | 60 | # 保存子剪辑 61 | subclip.write_videofile(output_filename, codec="libx264") 62 | 63 | # 更新起始时间 64 | start_time += split_t 65 | print(f"{cnt} done!") 66 | 67 | cnt += 1 68 | 69 | # 释放资源 70 | clip.close() 71 | 72 | print("视频分割完成") 73 | -------------------------------------------------------------------------------- /models/modules/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ref paper: Tensor Fusion Network for Multimodal Sentiment Analysis 3 | Ref url: https://github.com/Justin1904/TensorFusionNetworks 4 | """ 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | ## 这两个模块都是用在 TFN 中的 (video|audio) 9 | class MLPEncoder(nn.Module): 10 | ''' 11 | The subnetwork that is used in TFN for video and audio in the pre-fusion stage 12 | ''' 13 | 14 | def __init__(self, in_size, hidden_size, dropout): 15 | ''' 16 | Args: 17 | in_size: input dimension 18 | hidden_size: hidden layer dimension 19 | dropout: dropout probability 20 | Output: 21 | (return value in forward) a tensor of shape (batch_size, hidden_size) 22 | ''' 23 | super(MLPEncoder, self).__init__() 24 | # self.norm = nn.BatchNorm1d(in_size) 25 | self.drop = nn.Dropout(p=dropout) 26 | self.linear_1 = nn.Linear(in_size, hidden_size) 27 | self.linear_2 = nn.Linear(hidden_size, hidden_size) 28 | self.linear_3 = nn.Linear(hidden_size, hidden_size) 29 | 30 | def forward(self, x): 31 | ''' 32 | Args: 33 | x: tensor of shape (batch_size, in_size) 34 | ''' 35 | # normed = self.norm(x) 36 | dropped = self.drop(x) 37 | y_1 = F.relu(self.linear_1(dropped)) 38 | y_2 = F.relu(self.linear_2(y_1)) 39 | y_3 = F.relu(self.linear_3(y_2)) 40 | 41 | return y_3 42 | 43 | 44 | # TFN 中的文本编码,额外需要lstm 操作 [感觉是audio|video] 45 | class LSTMEncoder(nn.Module): 46 | ''' 47 | The LSTM-based subnetwork that is used in TFN for text 48 | ''' 49 | 50 | def __init__(self, in_size, hidden_size, dropout, num_layers=1, bidirectional=False): 51 | 52 | super(LSTMEncoder, self).__init__() 53 | 54 | if num_layers == 1: 55 | rnn_dropout = 0.0 56 | else: 57 | rnn_dropout = dropout 58 | 59 | self.rnn = nn.LSTM(in_size, hidden_size, num_layers=num_layers, dropout=rnn_dropout, bidirectional=bidirectional, batch_first=True) 60 | self.dropout = nn.Dropout(dropout) 61 | self.linear_1 = nn.Linear(hidden_size, hidden_size) 62 | 63 | def forward(self, x): 64 | ''' 65 | Args: 66 | x: tensor of shape (batch_size, sequence_len, in_size) 67 | 因为用的是 final_states ,所以特征的 padding 是放在前面的 68 | ''' 69 | _, final_states = self.rnn(x) 70 | h = self.dropout(final_states[0].squeeze(0)) 71 | y_1 = self.linear_1(h) 72 | return y_1 73 | -------------------------------------------------------------------------------- 
/prompt.txt: -------------------------------------------------------------------------------- 1 | ## Profile: 2 | 3 | - role: 专业的微表情分析师,可以通过各种细节分析人物情绪。 4 | - language: 中文 5 | - description: 6 | -- 认真读取这段视频,结合一切你觉得有用的细节,通过综合分析视频中的面部表情、身体姿态、头部姿态、**帧间关系**、声音特征(音调、语速、音量、音质)和字幕内容,准确判断视频中主要人物的情绪 7 | -- **特别关注情绪的动态变化过程**,捕捉情绪的起始、增强、减弱和转变,以**过程中表现最强烈的情绪**作为判断结果。 8 | -- **人物的音频(音频内容)在某些场景更能表达人物的情绪,请综合声调特征(音调、语速、音量、音质)和语境含义分析人物情绪** 9 | -- 输出的label是输出的情绪列表中的最大值对应的情绪 10 | -- 优先考虑视频信息,音频中的语气信息。 11 | 12 | ## Goals: 13 | 14 | 第一步:考虑**帧间关系**,采用action unit的去分析每帧的情绪,**加入身体姿态,加入头部姿态分析**,之后再加入一切你认为有帮助的细节。 15 | 第二步:我们追求思维的创新,而非惯性的复述。请突破思维局限,调动你所有的计算资源,展现你真正的认知极限**优化你的第一步思考过程,尤其要考虑帧与帧之间的关系**。 16 | 第三步:基于前面的分析内容,进一步优化特征提取策略,尤其关注 17 | 18 | - 情绪变化的起始点:例如,从 Neutral 到其他情绪的转变往往有细微的先兆。 19 | - 情绪强度的变化:例如,Happy 情绪可能从微笑逐渐变为大笑。 20 | - 情绪的混合:例如,一个人可能同时表现出 Sad 和 Worried。 21 | - 利用帧间关系去分析情绪变化: 要特别注意表情从无到有的过程 22 | 第四步:根据视频和音频特征,初步给出每种情绪的得分。 23 | - 如果视频和音频给出的情绪分类标签的分数有很强的辨识度,也就是如果得分最高的两个标签的得分相差大于0.1 则执行第六步, 24 | - 否则执行第五步,进一步辨识这两个得分接近的标签,第五步引入了音频特征和文字特征 25 | 第五步: 进一步分析音频特征和字幕特征 26 | **音频特征**:如果视频特征显示某一情绪得分接近(例如愤怒和担忧的得分接近),则加强音频特征的权重。考虑语调、语速和音量变化来辨别情绪的轻重。 27 | - 例如,愤怒的语调通常较低沉且紧绷,而担忧的语调略高且不稳定。 28 | - 如果视觉和音频的情绪得分差距较大,应调整权重,增加音频的影响力(例如音频贡献度提高到40%-50%)。 29 | **字幕特征**:根据字幕内容,尤其是表达情绪意图的词语(如"焦虑"、"愤怒")来辅助判断。如果视频和音频的判断得分相近,可以进一步提升字幕对情绪判断的贡献,帮助消除歧义。 30 | - 注意避免文字的误导性判断。例如,"你在干什么?" 可能是疑问;但是语气很愤怒时,也可以表示愤怒;表情开心时,可以解读为惊喜。 31 | 第六步:总结你的所有依据,给出最终结论,输出的推理过程应该保持简洁(极其重要)。 32 | 33 | 34 | ## Constraints: 35 | 36 | - 请用以下规范输出: 37 | - 情绪只能为["neutral", "angry", "happy", "sad", "worried", "surprise"]中的一个; 38 | - 对于视频中存在多个人物的情况,只需要分析出镜时间最长的主要人物的情绪; 39 | - 确保输出的json格式的正确性; 40 | - 语言平实、具体、简明、准确; 41 | - 优先选择具体名词替代抽象概念; 42 | - 保持段落简明(不超过5行); 43 | - 禁用文学化修辞; 44 | - 重点信息前置; 45 | - 复杂内容分点说明; 46 | - 保持口语化但不过度简化专业内容; 47 | - 确保信息准确前提下优先选择大众认知词汇; 48 | - **worried 的优先级相对较低**。仅在视频中有明显 Worried 表现且其他情绪不明显时预测为 Worried。 49 | - Sad/Surprise/Neutral 优先级:在这些情绪与 Worried 接近时,优先考虑 Sad/Surprise/Neutral。 50 | 51 | ## OutputFormat: 52 | 53 | { 54 | "推理过程": "主人公的面部表情和语气都很愉快,所以我认为他是开心的, 但是他的眼神有点忧郁,所以我认为他也有点伤心, 但是他的语气很平静,所以我认为他不是很生气", 55 | "情绪": { 56 | "neutral": 0.0, 57 | "angry": 0.0, 58 | "happy": 0.0, 59 | "sad": 0.0, 60 | "worried": 0.0, 61 | "surprise": 0.0 62 | }, 63 | "label": "happy", 64 | "贡献度":{ 65 | "视频": "0%", 66 | "音频": "0%", 67 | "字幕": "0%" 68 | } 69 | 70 | } 71 | 72 | 73 | ## Example: 74 | 75 | { 76 | "推理过程": "主人公的面部表情和语气都很愉快,所以我认为他是开心的, 但是他的眼神有点忧郁,所以我认为他也有点伤心, 但是他的语气很平静,所以我认为他不是很生气", 77 | "情绪": { 78 | "neutral": 0.3, 79 | "angry": 0.01, 80 | "happy": 0.8, 81 | "sad": 0.01, 82 | "worried": 0.1, 83 | "surprise": 0.4 84 | }, 85 | "label": "happy", 86 | "贡献度":{ 87 | "视频": "40%", 88 | "音频": "30%", 89 | "字幕": "30%" 90 | } 91 | } -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Description: build [train / valid / test] dataloader 3 | Author: Xiongjun Guan 4 | Date: 2023-03-01 19:41:05 5 | version: 0.0.1 6 | LastEditors: Xiongjun Guan 7 | LastEditTime: 2023-03-20 21:01:47 8 | 9 | Copyright (C) 2023 by Xiongjun Guan, Tsinghua University. All rights reserved. 
10 | """ 11 | 12 | import copy 13 | import logging 14 | import os.path as osp 15 | import random 16 | 17 | import cv2 18 | import numpy as np 19 | from scipy import signal 20 | from torch.utils.data import DataLoader, Dataset 21 | 22 | 23 | class load_dataset_train(Dataset): 24 | 25 | def __init__( 26 | self, 27 | names, 28 | emo_labels, 29 | val_labels, 30 | emo_rule, 31 | feat_roots, 32 | ): 33 | self.names = names 34 | self.emo_labels = emo_labels 35 | self.val_labels = val_labels 36 | self.feat_roots = feat_roots 37 | 38 | self.emo2idx_mer, self.idx2emo_mer = {}, {} 39 | for ii, emo in enumerate(emo_rule): 40 | self.emo2idx_mer[emo] = ii 41 | for ii, emo in enumerate(emo_rule): 42 | self.idx2emo_mer[ii] = emo 43 | 44 | def __len__(self): 45 | return len(self.names) 46 | 47 | def __getitem__(self, idx): 48 | emo = self.emo2idx_mer[self.emo_labels[idx]] 49 | val = self.val_labels[idx] 50 | name = self.names[idx] 51 | 52 | inputs = {} 53 | idx = 0 54 | for feat_root in self.feat_roots: 55 | ftitle = osp.basename(name).split(".")[0] 56 | inputs["feat{}".format(idx)] = ( 57 | np.load(osp.join(feat_root, ftitle + ".npy")) 58 | .squeeze() 59 | .astype(np.float32) 60 | ) 61 | idx += 1 62 | 63 | return inputs, emo, val, name 64 | 65 | 66 | def get_dataloader_train( 67 | names, 68 | emo_labels, 69 | val_labels, 70 | emo_rule, 71 | feat_roots, 72 | batch_size=1, 73 | num_workers=4, 74 | shuffle=True, 75 | ): 76 | # Create dataset 77 | try: 78 | dataset = load_dataset_train( 79 | names=names, 80 | emo_labels=emo_labels, 81 | val_labels=val_labels, 82 | emo_rule=emo_rule, 83 | feat_roots=feat_roots, 84 | ) 85 | except Exception as e: 86 | logging.error("Error in DataLoader: ", repr(e)) 87 | return 88 | 89 | train_loader = DataLoader( 90 | dataset, 91 | batch_size=batch_size, 92 | shuffle=shuffle, 93 | num_workers=num_workers, 94 | pin_memory=True, 95 | ) 96 | logging.info(f"n_train:{len(dataset)}") 97 | 98 | return train_loader 99 | 100 | 101 | def get_dataloader_valid( 102 | names, 103 | emo_labels, 104 | val_labels, 105 | emo_rule, 106 | feat_roots, 107 | batch_size=1, 108 | num_workers=4, 109 | shuffle=False, 110 | ): 111 | # Create dataset 112 | try: 113 | dataset = load_dataset_train( 114 | names=names, 115 | emo_labels=emo_labels, 116 | val_labels=val_labels, 117 | emo_rule=emo_rule, 118 | feat_roots=feat_roots, 119 | ) 120 | except Exception as e: 121 | logging.error("Error in DataLoader: ", repr(e)) 122 | return 123 | 124 | valid_loader = DataLoader( 125 | dataset, 126 | batch_size=batch_size, 127 | shuffle=shuffle, 128 | num_workers=num_workers, 129 | pin_memory=True, 130 | ) 131 | logging.info(f"n_valid:{len(dataset)}") 132 | 133 | return valid_loader 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # More Is Better: A MoE-Based Emotion Recognition Framework with Human Preference Alignment (MRAC @ ACM-MM 2025) 2 | 3 | [![License: Apache-2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2508.06036-B21A1B)](https://arxiv.org/abs/2508.06036) 5 | [![GitHub Stars](https://img.shields.io/github/stars/zhuyjan/MER2025-MRAC25?style=social)](https://github.com/zhuyjan/MER2025-MRAC25/stargazers) 6 | 7 | This repository provides our runner-up solution for [MER2025-SEMI challenge](https://zeroqiaoba.github.io/MER2025-website/) at MRAC'25 workshop. 
If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏 8 | 9 | ## Overview 10 |
<img src="images/fig1.png" alt="method">
11 | 12 | We propose a comprehensive framework, grounded in the principle that "more is better," to construct a robust Mixture of Experts (MoE) emotion recognition system. Our approach integrates a diverse range of input modalities as independent experts, including novel signals such as knowledge from large Vision-Language Models (VLMs) and temporal Action Unit (AU) information. 13 | 14 | ## Setup 15 | Our environment setup is identical to that of [MERBench](https://github.com/zeroQiaoba/MERTools/tree/master/MERBench). 16 | ```bash 17 | conda env create -f environment.yml 18 | ``` 19 | 20 | ## Workflow 21 | 22 | **Preprocessing** 23 | 1. Extract features for each sample. Each sample should correspond to a `.npy` file with shape `(C,)`. 24 | You can refer to [MER2025 Track1](https://github.com/zeroQiaoba/MERTools/tree/master/MER2025/MER2025_Track1) for more details on feature extraction. 25 | (We also provide our Gemini prompt for reference — see `prompt.txt` for details.) 26 | 2. Run `./preprocess/check_feat_dim.py` to obtain the dimensionality of each type of feature. 27 | 3. Fill in the feature names and their corresponding dimensions in the modality-specific dictionaries in `./config.py`. 28 | 4. Run `./preprocess/mer2025_base.py` to split the dataset into training and validation sets. 29 | 30 | **Training** 31 | - Run `./train.py` to perform training and validation. 32 | - Run `./run_eval_unimodal.sh` to evaluate unimodal performance. (All three branches use the same input for fair comparison with the trimodal setting.) 33 | 34 | ## Training Information Format 35 | 36 | Saved as `.npy` files under `./train_lst/`, each containing a Python dictionary with the following structure: 37 | 38 | ```python 39 | train_info = { 40 | # List of video names, e.g., ["example1", ...] 41 | "names": split_train_names, 42 | # List of emotion labels, e.g., ["happy", ...] 43 | "emos": split_train_emos, 44 | # List of emotion valence values, e.g., [-10, ...] 45 | "vals": split_train_vals, 46 | } 47 | 48 | valid_info = { 49 | # List of video names, e.g., ["example1", ...] 50 | "names": split_valid_names, 51 | # List of emotion labels, e.g., ["happy", ...] 52 | "emos": split_valid_emos, 53 | # List of emotion valence values, e.g., [-10, ...] 54 | "vals": split_valid_vals, 55 | } 56 | ``` 57 | 58 | ## Citation 59 | If you find this work useful for your research, please give us a star and use the following BibTeX entry for citation. 60 | ``` 61 | @inproceedings{xie2025more, 62 | title={More is better: A moe-based emotion recognition framework with human preference alignment}, 63 | author={Xie, Jun and Zhu, Yingjian and Chen, Feng and Zhang, Zhenghao and Fan, Xiaohui and Yi, Hongzhu and Wang, Xinming and Yu, Chen and Bi, Yue and Zhao, Zhaoran and others}, 64 | booktitle={Proceedings of the 3rd International Workshop on Multimodal and Responsible Affective Computing}, 65 | pages={2--7}, 66 | year={2025} 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /feature_extraction/visual/util.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import re 4 | import pandas as pd 5 | import numpy as np 6 | import struct 7 | 8 | 9 | ## for OPENFACE 10 | ## reference: https://gist.github.com/btlorch/6d259bfe6b753a7a88490c0607f07ff8 11 | def read_hog(filename, batch_size=5000): 12 | """ 13 | Read HoG features file created by OpenFace. 
14 | For each frame, OpenFace extracts 12 * 12 * 31 HoG features, i.e., num_features = 4464. These features are stored in row-major order. 15 | :param filename: path to .hog file created by OpenFace 16 | :param batch_size: how many rows to read at a time 17 | :return: is_valid, hog_features 18 | is_valid: ndarray of shape [num_frames] 19 | hog_features: ndarray of shape [num_frames, num_features] 20 | """ 21 | all_feature_vectors = [] 22 | with open(filename, "rb") as f: 23 | (num_cols,) = struct.unpack("i", f.read(4)) # 12 24 | (num_rows,) = struct.unpack("i", f.read(4)) # 12 25 | (num_channels,) = struct.unpack("i", f.read(4)) # 31 26 | 27 | # The first four bytes encode a boolean value whether the frame is valid 28 | num_features = 1 + num_rows * num_cols * num_channels 29 | feature_vector = struct.unpack( 30 | "{}f".format(num_features), f.read(num_features * 4) 31 | ) 32 | feature_vector = np.array(feature_vector).reshape( 33 | (1, num_features) 34 | ) # [1, 4464+1] 35 | all_feature_vectors.append(feature_vector) 36 | 37 | # Every frame contains a header of four float values: num_cols, num_rows, num_channels, is_valid 38 | num_floats_per_feature_vector = 4 + num_rows * num_cols * num_channels 39 | # Read in batches of given batch_size 40 | num_floats_to_read = num_floats_per_feature_vector * batch_size 41 | # Multiply by 4 because of float32 42 | num_bytes_to_read = num_floats_to_read * 4 43 | 44 | while True: 45 | bytes = f.read(num_bytes_to_read) 46 | # For comparison how many bytes were actually read 47 | num_bytes_read = len(bytes) 48 | assert ( 49 | num_bytes_read % 4 == 0 50 | ), "Number of bytes read does not match with float size" 51 | num_floats_read = num_bytes_read // 4 52 | assert ( 53 | num_floats_read % num_floats_per_feature_vector == 0 54 | ), "Number of bytes read does not match with feature vector size" 55 | num_feature_vectors_read = num_floats_read // num_floats_per_feature_vector 56 | 57 | feature_vectors = struct.unpack("{}f".format(num_floats_read), bytes) 58 | # Convert to array 59 | feature_vectors = np.array(feature_vectors).reshape( 60 | (num_feature_vectors_read, num_floats_per_feature_vector) 61 | ) 62 | # Discard the first three values in each row (num_cols, num_rows, num_channels) 63 | feature_vectors = feature_vectors[:, 3:] 64 | # Append to list of all feature vectors that have been read so far 65 | all_feature_vectors.append(feature_vectors) 66 | 67 | if num_bytes_read < num_bytes_to_read: 68 | break 69 | 70 | # Concatenate batches 71 | all_feature_vectors = np.concatenate(all_feature_vectors, axis=0) 72 | 73 | # Split into is-valid and feature vectors 74 | is_valid = all_feature_vectors[:, 0] 75 | feature_vectors = all_feature_vectors[:, 1:] 76 | 77 | return is_valid, feature_vectors 78 | 79 | 80 | ## for OPENFACE 81 | def read_csv(filename, startIdx): 82 | data = pd.read_csv(filename) 83 | all_feature_vectors = [] 84 | for index in data.index: 85 | features = np.array(data.iloc[index][startIdx:]) 86 | all_feature_vectors.append(features) 87 | all_feature_vectors = np.array(all_feature_vectors) 88 | return all_feature_vectors 89 | -------------------------------------------------------------------------------- /preprocess/mer2025_base.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import sys 5 | 6 | import pandas as pd 7 | 8 | sys.path.append("/mnt/public/gxj_2/EmoNet_Pro/") 9 | import argparse 10 | import os.path as osp 11 | import random 12 | from collections 
import Counter 13 | 14 | import numpy as np 15 | import torch 16 | from tqdm import tqdm 17 | 18 | import config 19 | 20 | 21 | def set_seed(seed=0): 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | torch.backends.cudnn.benchmark = False 28 | torch.backends.cudnn.deterministic = True 29 | 30 | 31 | # run -d toolkit/preprocess/mer2024.py 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | "--seed", 36 | type=int, 37 | default=0, 38 | help="random seed", 39 | ) 40 | args = parser.parse_args() 41 | 42 | random_seed = args.seed 43 | set_seed(random_seed) 44 | print("### random seed: {}".format(random_seed)) 45 | 46 | # --- data config 47 | emo_rule = "MER" 48 | gt_path = "/mnt/public/share/data/MER2025/mer2025-dataset/track1_train_disdim.csv" 49 | 50 | # --- save config 51 | save_root = "/mnt/public/gxj_2/EmoNet_Pro/lst_train/mer25_train_val" 52 | save_name = "seed{}".format(random_seed) 53 | 54 | # --- load data and format 55 | mapping_rules = config.EMO_RULE[emo_rule] 56 | 57 | all_names = [] 58 | all_emos = [] 59 | all_vals = [] 60 | df = pd.read_csv(gt_path) 61 | for index, row in tqdm(df.iterrows(), total=len(df)): 62 | name = row["name"] 63 | discrete = row["discrete"] 64 | valence = row["valence"] 65 | 66 | # emo = mapping_rules.index(discrete) 67 | 68 | all_names.append(name) 69 | all_emos.append(discrete) 70 | all_vals.append(valence) 71 | 72 | counted_numbers = {} 73 | for emo in all_emos: 74 | if emo in counted_numbers: 75 | counted_numbers[emo] += 1 76 | else: 77 | counted_numbers[emo] = 1 78 | for emo in mapping_rules: 79 | print( 80 | "{}, num={}, percent={:.2f}%".format( 81 | emo, 82 | counted_numbers[emo], 83 | 100 * counted_numbers[emo] / len(all_emos), 84 | ) 85 | ) 86 | 87 | # --------------- split train & test --------------- 88 | split_ratio = 0.2 89 | whole_num = len(all_names) 90 | 91 | # gain indices for cross-validation 92 | indices = np.arange(whole_num) 93 | random.shuffle(indices) 94 | 95 | # split indices into 1-fold 96 | each_folder_num = int(whole_num * split_ratio) 97 | valid_idxs = indices[0:each_folder_num] 98 | train_idxs = indices[each_folder_num:] 99 | 100 | split_train_names = [] 101 | split_train_emos = [] 102 | split_train_vals = [] 103 | for idx in train_idxs: 104 | split_train_names.append(all_names[idx]) 105 | split_train_emos.append(all_emos[idx]) 106 | split_train_vals.append(all_vals[idx]) 107 | 108 | split_valid_names = [] 109 | split_valid_emos = [] 110 | split_valid_vals = [] 111 | for idx in valid_idxs: 112 | split_valid_names.append(all_names[idx]) 113 | split_valid_emos.append(all_emos[idx]) 114 | split_valid_vals.append(all_vals[idx]) 115 | 116 | train_info = { 117 | "names": split_train_names, 118 | "emos": split_train_emos, 119 | "vals": split_train_vals, 120 | } 121 | valid_info = { 122 | "names": split_valid_names, 123 | "emos": split_valid_emos, 124 | "vals": split_valid_vals, 125 | } 126 | 127 | print("----------- summary") 128 | print( 129 | "split_train: {}\nsplit_valid: {}".format( 130 | len(train_info["names"]), len(valid_info["names"]) 131 | ) 132 | ) 133 | 134 | save_path = os.path.join(save_root, save_name + ".npy") 135 | if not os.path.exists(save_root): 136 | os.makedirs(save_root) 137 | 138 | np.save( 139 | save_path, 140 | { 141 | "train_info": train_info, 142 | "valid_info": valid_info, 143 | }, 144 | ) 145 | 
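Editor's note: the split file written by `mer2025_base.py` above can be loaded back for a quick sanity check. The snippet below is a minimal sketch and is not part of the repository; the path is an assumption based on the script's default `save_root`/`save_name` with `--seed 0`, so adjust it to your own setup.

```python
import numpy as np

# Minimal sketch: load the split produced by mer2025_base.py and inspect it.
# The path below is an assumption (default save_root/save_name, seed 0).
split_path = "/mnt/public/gxj_2/EmoNet_Pro/lst_train/mer25_train_val/seed0.npy"

info = np.load(split_path, allow_pickle=True).item()
train_info, valid_info = info["train_info"], info["valid_info"]

print("train samples:", len(train_info["names"]))
print("valid samples:", len(valid_info["names"]))
# Each entry pairs a video name with its discrete emotion label and valence value.
print(train_info["names"][0], train_info["emos"][0], train_info["vals"][0])
```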
-------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # ----------------- path to train/test info 2 | PATH_TO_TRAIN_LST = "/mnt/public/gxj_2/EmoNet_Pro/lst_train/mer25_train_val" 3 | PATH_TO_TEST_LST = "/mnt/public/gxj_2/EmoNet_Pro/lst_test/" 4 | 5 | # ----------------- emo / index matching rules 6 | EMO_RULE = { 7 | "MER": ["neutral", "angry", "happy", "sad", "worried", "surprise"], 8 | "CREMA-D": ["neutral", "angry", "happy", "sad", "fear", "disgust"], 9 | "TESS": ["neutral", "angry", "happy", "sad", "fear", "disgust"], 10 | "RAVDESS": [ 11 | "neutral", 12 | "angry", 13 | "happy", 14 | "sad", 15 | "fear", 16 | "disgust", 17 | "surprised", 18 | "calm", 19 | ], 20 | } 21 | 22 | # ----------------- features can be used 23 | FEAT_VIDEO_DICT = { 24 | "senet50face_UTT": [ 25 | 512, 26 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/senet50face_UTT", 27 | ], 28 | "resnet50face_UTT": [ 29 | 512, 30 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/resnet50face_UTT", 31 | ], 32 | "clip-vit-large-patch14-UTT": [ 33 | 768, 34 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/clip-vit-large-patch14-UTT", 35 | ], 36 | "clip-vit-base-patch32-UTT": [ 37 | 512, 38 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/clip-vit-base-patch32-UTT", 39 | ], 40 | "videomae-large-UTT": [ 41 | 1024, 42 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/videomae-large-UTT", 43 | ], 44 | "videomae-base-UTT": [ 45 | 768, 46 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/videomae-base-UTT", 47 | ], 48 | "manet_UTT": [ 49 | 1024, 50 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/manet_UTT", 51 | ], 52 | "emonet_UTT": [ 53 | 256, 54 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/emonet_UTT", 55 | ], 56 | "dinov2-large-UTT": [ 57 | 1024, 58 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/dinov2-large-UTT", 59 | ], 60 | "InternVL_2_5_HiCo_R16-UTT": [ 61 | 4096, 62 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/InternVL_2_5_HiCo_R16-UTT", 63 | ], 64 | } 65 | 66 | FEAT_AUDIO_DICT = { 67 | "chinese-hubert-large-UTT": [ 68 | 1024, 69 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-hubert-large-UTT", 70 | ], 71 | "Qwen2-Audio-7B-UTT": [ 72 | 1280, 73 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/Qwen2-Audio-7B-UTT", 74 | ], 75 | "chinese-hubert-base-UTT": [ 76 | 768, 77 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-hubert-base-UTT", 78 | ], 79 | "whisper-large-v2-UTT": [ 80 | 1280, 81 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/whisper-large-v2-UTT", 82 | ], 83 | "chinese-wav2vec2-large-UTT": [ 84 | 1024, 85 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-wav2vec2-large-UTT", 86 | ], 87 | "chinese-wav2vec2-base-UTT": [ 88 | 768, 89 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-wav2vec2-base-UTT", 90 | ], 91 | "wavlm-base-UTT": [ 92 | 768, 93 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/wavlm-base-UTT", 94 | ], 95 | } 96 | 97 | FEAT_TEXT_DICT = { 98 | "chinese-roberta-wwm-ext-large-UTT": [ 99 | 1024, 100 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-roberta-wwm-ext-large-UTT", 101 | ], 102 | "chinese-roberta-wwm-ext-UTT": 
[ 103 | 768, 104 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-roberta-wwm-ext-UTT", 105 | ], 106 | "chinese-macbert-large-UTT": [ 107 | 1024, 108 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-macbert-large-UTT", 109 | ], 110 | "chinese-macbert-base-UTT": [ 111 | 768, 112 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-macbert-base-UTT", 113 | ], 114 | "bloom-7b1-UTT": [ 115 | 4096, 116 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/bloom-7b1-UTT", 117 | ], 118 | } 119 | 120 | MODEL_DIR_DICT = {} 121 | -------------------------------------------------------------------------------- /preprocess/mer2024_base.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import sys 5 | 6 | sys.path.append("/mnt/public/gxj_2/EmoNet_Pro/") 7 | import argparse 8 | import os.path as osp 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | 14 | import config 15 | # from toolkit.utils.read_files import * 16 | from toolkit.utils.read_files import func_read_key_from_csv 17 | 18 | 19 | def set_seed(seed=0): 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.benchmark = False 26 | torch.backends.cudnn.deterministic = True 27 | 28 | 29 | def read_data_into_lists(filename, emo_rule): 30 | sample_ids = [] 31 | numbers = [] 32 | emotions = [] 33 | float_values = [] 34 | 35 | with open(filename, "r") as file: 36 | for line in file: 37 | parts = line.strip().split() 38 | 39 | 40 | if parts[2] not in config.EMO_RULE[emo_rule]: 41 | continue 42 | 43 | sample_ids.append(parts[0]) 44 | numbers.append(int(parts[1])) 45 | emotions.append(parts[2]) 46 | float_values.append(float(parts[3])) 47 | 48 | return sample_ids, numbers, emotions, float_values 49 | 50 | 51 | def normalize_dataset_format( 52 | emo_rule, 53 | gt_path, 54 | video_root, 55 | feat_roots, 56 | ): 57 | 58 | all_names, _, all_emos, all_vals = read_data_into_lists( 59 | gt_path, emo_rule 60 | ) 61 | 62 | # --------------- split train & test --------------- 63 | split_ratio = 0.2 64 | whole_num = len(all_names) 65 | 66 | # gain indices for cross-validation 67 | indices = np.arange(whole_num) 68 | random.shuffle(indices) 69 | 70 | # split indices into 1-fold 71 | each_folder_num = int(whole_num * split_ratio) 72 | test_idxs = indices[0:each_folder_num] 73 | train_idxs = indices[each_folder_num:] 74 | 75 | split_train_names = [] 76 | split_train_emos = [] 77 | split_train_vals = [] 78 | for idx in train_idxs: 79 | name = all_names[idx] 80 | train_name = osp.join(video_root, "video", name) 81 | split_train_names.append(train_name) 82 | split_train_emos.append(all_emos[idx]) 83 | split_train_vals.append(all_vals[idx]) 84 | 85 | 86 | feat_dims = [] 87 | for feat_root in feat_roots: 88 | feat_dim = np.load(osp.join(feat_root,all_names[0]+".npy")).shape[0] 89 | feat_dims.append(feat_dim) 90 | 91 | split_test_names = [] 92 | split_test_emos = [] 93 | split_test_vals = [] 94 | for idx in test_idxs: 95 | name = all_names[idx] 96 | test_name = osp.join(video_root, name) 97 | split_test_names.append(test_name) 98 | split_test_emos.append(all_emos[idx]) 99 | split_test_vals.append(all_vals[idx]) 100 | 101 | cnt=1 102 | for dim,root in zip(feat_dims,feat_roots): 103 | print("----------- feat {}".format(cnt)) 104 | print("# root={}".format(root)) 105 | 
print("# dim={}".format(dim)) 106 | cnt+=1 107 | 108 | train_info = { 109 | "names": split_train_names, 110 | "emos": split_train_emos, 111 | "vals": split_train_vals, 112 | } 113 | valid_info = { 114 | "names": split_test_names, 115 | "emos": split_test_emos, 116 | "vals": split_test_vals, 117 | } 118 | 119 | return { 120 | "feat_dims": feat_dims, 121 | "feat_roots": feat_roots, 122 | "train_info": train_info, 123 | "valid_info": valid_info, 124 | } 125 | 126 | 127 | # run -d toolkit/preprocess/mer2024.py 128 | if __name__ == "__main__": 129 | random_seed = 1 130 | set_seed(random_seed) 131 | print("random seed: {}".format(random_seed)) 132 | 133 | # --- data config 134 | emo_rule = "MER" 135 | gt_path = "/mnt/public/share/data/Dataset/MER2024/MER2024-labeled.txt" 136 | video_root = "/mnt/public/share/data/Dataset/MER2024/video/" 137 | 138 | # --- feature config 139 | description = "mer24-train" 140 | feat_roots = ["/mnt/public/share/data/Dataset/MER2024/features/clip-vit-large-patch14-UTT/", 141 | "/mnt/public/share/data/Dataset/MER2024/features/chinese-hubert-large-UTT/", 142 | "/mnt/public/share/data/Dataset/MER2024/features/bloom-7b1-UTT/",] 143 | 144 | # --- save config 145 | save_root = "/mnt/public/gxj_2/EmoNet_Pro/lst_train/" 146 | save_name = "seed{}_MER24".format(random_seed) 147 | 148 | # --- run 149 | dataset_info = normalize_dataset_format( 150 | emo_rule, 151 | gt_path, 152 | video_root, 153 | feat_roots, 154 | ) 155 | 156 | feat_dims = dataset_info["feat_dims"] 157 | feat_roots = dataset_info["feat_roots"] 158 | train_info = dataset_info["train_info"] 159 | valid_info = dataset_info["valid_info"] 160 | 161 | train_info["vals"] = [-100] * len(train_info["vals"]) 162 | valid_info["vals"] = [-100] * len(valid_info["vals"]) 163 | 164 | save_path = os.path.join(save_root, save_name + ".npy") 165 | if not os.path.exists(save_root): 166 | os.makedirs(save_root) 167 | 168 | np.save( 169 | save_path, 170 | { 171 | "feat_dims": feat_dims, 172 | "feat_roots": feat_roots, 173 | "train_info": train_info, 174 | "valid_info": valid_info, 175 | }, 176 | ) 177 | 178 | print("----------- summary") 179 | print( 180 | "split_train: {}\nsplit_valid: {}".format( 181 | len(train_info["names"]), len(valid_info["names"]) 182 | ) 183 | ) 184 | -------------------------------------------------------------------------------- /models/auto_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .modules.encoder import LSTMEncoder, MLPEncoder 5 | 6 | 7 | class Auto_Attention(nn.Module): 8 | def __init__(self, args): 9 | super(Auto_Attention, self).__init__() 10 | feat_dims = args.feat_dims 11 | output_dim1 = args.output_dim1 12 | output_dim2 = args.output_dim2 13 | dropout = args.dropout 14 | hidden_dim = args.hidden_dim 15 | self.grad_clip = args.grad_clip 16 | self.feat_dims = feat_dims 17 | 18 | assert args.feat_type == "utt" 19 | 20 | assert len(feat_dims) <= 3 * 6 21 | if len(feat_dims) >= 1: 22 | self.encoder0 = MLPEncoder(feat_dims[0], hidden_dim, dropout) 23 | if len(feat_dims) >= 2: 24 | self.encoder1 = MLPEncoder(feat_dims[1], hidden_dim, dropout) 25 | if len(feat_dims) >= 3: 26 | self.encoder2 = MLPEncoder(feat_dims[2], hidden_dim, dropout) 27 | if len(feat_dims) >= 4: 28 | self.encoder3 = MLPEncoder(feat_dims[3], hidden_dim, dropout) 29 | if len(feat_dims) >= 5: 30 | self.encoder4 = MLPEncoder(feat_dims[4], hidden_dim, dropout) 31 | if len(feat_dims) >= 6: 32 | self.encoder5 = 
MLPEncoder(feat_dims[5], hidden_dim, dropout) 33 | if len(feat_dims) >= 7: 34 | self.encoder6 = MLPEncoder(feat_dims[6], hidden_dim, dropout) 35 | if len(feat_dims) >= 8: 36 | self.encoder7 = MLPEncoder(feat_dims[7], hidden_dim, dropout) 37 | if len(feat_dims) >= 9: 38 | self.encoder8 = MLPEncoder(feat_dims[8], hidden_dim, dropout) 39 | if len(feat_dims) >= 10: 40 | self.encoder9 = MLPEncoder(feat_dims[9], hidden_dim, dropout) 41 | if len(feat_dims) >= 11: 42 | self.encoder10 = MLPEncoder(feat_dims[10], hidden_dim, dropout) 43 | if len(feat_dims) >= 12: 44 | self.encoder11 = MLPEncoder(feat_dims[11], hidden_dim, dropout) 45 | if len(feat_dims) >= 13: 46 | self.encoder12 = MLPEncoder(feat_dims[12], hidden_dim, dropout) 47 | if len(feat_dims) >= 14: 48 | self.encoder13 = MLPEncoder(feat_dims[13], hidden_dim, dropout) 49 | if len(feat_dims) >= 15: 50 | self.encoder14 = MLPEncoder(feat_dims[14], hidden_dim, dropout) 51 | if len(feat_dims) >= 16: 52 | self.encoder15 = MLPEncoder(feat_dims[15], hidden_dim, dropout) 53 | if len(feat_dims) >= 17: 54 | self.encoder16 = MLPEncoder(feat_dims[16], hidden_dim, dropout) 55 | if len(feat_dims) >= 18: 56 | self.encoder17 = MLPEncoder(feat_dims[17], hidden_dim, dropout) 57 | 58 | self.attention_mlp = MLPEncoder( 59 | hidden_dim * len(feat_dims), hidden_dim, dropout 60 | ) 61 | 62 | self.fc_att = nn.Linear(hidden_dim, 3) 63 | self.fc_out_1 = nn.Linear(hidden_dim, output_dim1) 64 | self.fc_out_2 = nn.Linear(hidden_dim, output_dim2) 65 | 66 | def forward(self, batch): 67 | """ 68 | support feat_type: utt | frm-align | frm-unalign 69 | """ 70 | hiddens = [] 71 | if len(self.feat_dims) >= 1: 72 | hiddens.append(self.encoder0(batch[f"feat0"])) 73 | if len(self.feat_dims) >= 2: 74 | hiddens.append(self.encoder1(batch[f"feat1"])) 75 | if len(self.feat_dims) >= 3: 76 | hiddens.append(self.encoder2(batch[f"feat2"])) 77 | if len(self.feat_dims) >= 4: 78 | hiddens.append(self.encoder3(batch[f"feat3"])) 79 | if len(self.feat_dims) >= 5: 80 | hiddens.append(self.encoder4(batch[f"feat4"])) 81 | if len(self.feat_dims) >= 6: 82 | hiddens.append(self.encoder5(batch[f"feat5"])) 83 | if len(self.feat_dims) >= 7: 84 | hiddens.append(self.encoder6(batch[f"feat6"])) 85 | if len(self.feat_dims) >= 8: 86 | hiddens.append(self.encoder7(batch[f"feat7"])) 87 | if len(self.feat_dims) >= 9: 88 | hiddens.append(self.encoder8(batch[f"feat8"])) 89 | if len(self.feat_dims) >= 10: 90 | hiddens.append(self.encoder9(batch[f"feat9"])) 91 | if len(self.feat_dims) >= 11: 92 | hiddens.append(self.encoder10(batch[f"feat10"])) 93 | if len(self.feat_dims) >= 12: 94 | hiddens.append(self.encoder11(batch[f"feat11"])) 95 | if len(self.feat_dims) >= 13: 96 | hiddens.append(self.encoder12(batch[f"feat12"])) 97 | if len(self.feat_dims) >= 14: 98 | hiddens.append(self.encoder13(batch[f"feat13"])) 99 | if len(self.feat_dims) >= 15: 100 | hiddens.append(self.encoder14(batch[f"feat14"])) 101 | if len(self.feat_dims) >= 16: 102 | hiddens.append(self.encoder15(batch[f"feat15"])) 103 | if len(self.feat_dims) >= 17: 104 | hiddens.append(self.encoder16(batch[f"feat16"])) 105 | if len(self.feat_dims) >= 18: 106 | hiddens.append(self.encoder17(batch[f"feat17"])) 107 | 108 | multi_hidden1 = torch.cat(hiddens, dim=1) # [32, 384] 109 | attention = self.attention_mlp(multi_hidden1) 110 | attention = self.fc_att(attention) 111 | attention = torch.unsqueeze(attention, 2) # [32, 3, 1] 112 | 113 | multi_hidden2 = torch.stack(hiddens, dim=2) # [32, 128, 3] 114 | fused_feat = torch.matmul( 115 | multi_hidden2, attention 116 | 
) # [32, 128, 3] * [32, 3, 1] = [32, 128, 1] 117 | 118 | features = fused_feat.squeeze(axis=2) # [32, 128] => 解决batch=1报错的问题 119 | emos_out = self.fc_out_1(features) 120 | vals_out = self.fc_out_2(features) 121 | interloss = torch.tensor(0).cuda() 122 | 123 | return features, emos_out, vals_out, interloss 124 | -------------------------------------------------------------------------------- /feature_extraction/text/split_asr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import glob 4 | import numpy as np 5 | import pandas as pd 6 | import sys 7 | import argparse 8 | 9 | current_file_path = os.path.abspath(__file__) 10 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path))) 11 | import config 12 | 13 | 14 | # 功能3:从csv中读取特定的key对应的值 15 | def func_read_key_from_csv(csv_path, key): 16 | values = [] 17 | df = pd.read_csv(csv_path) 18 | for _, row in df.iterrows(): 19 | if key not in row: 20 | values.append("") 21 | else: 22 | value = row[key] 23 | if pd.isna(value): 24 | value = "" 25 | values.append(value) 26 | return values 27 | 28 | 29 | # names[ii] -> keys=name2key[names[ii]], containing keynames 30 | def func_write_key_to_csv(csv_path, names, name2key, keynames): 31 | ## specific case: only save names 32 | if len(name2key) == 0 or len(keynames) == 0: 33 | df = pd.DataFrame(data=names, columns=["name"]) 34 | df.to_csv(csv_path, index=False) 35 | return 36 | 37 | ## other cases: 38 | if isinstance(keynames, str): 39 | keynames = [keynames] 40 | assert isinstance(keynames, list) 41 | columns = ["name"] + keynames 42 | 43 | values = [] 44 | for name in names: 45 | value = name2key[name] 46 | values.append(value) 47 | values = np.array(values) 48 | # ensure keynames is mapped 49 | if len(values.shape) == 1: 50 | assert len(keynames) == 1 51 | else: 52 | assert values.shape[-1] == len(keynames) 53 | data = np.column_stack([names, values]) 54 | 55 | df = pd.DataFrame(data=data, columns=columns) 56 | df.to_csv(csv_path, index=False) 57 | 58 | 59 | # python main-asr.py generate_transcription_files_asr ./dataset-process/audio ./dataset-process/transcription.csv 60 | def generate_transcription_files_asr(audio_root, save_path): 61 | import torch 62 | 63 | # import wenetruntime as wenet # must load torch first 64 | import wenet 65 | 66 | # from paddlespeech.cli.text.infer import TextExecutor 67 | # text_punc = TextExecutor() 68 | # decoder = wenet.Decoder(config.PATH_TO_WENET, lang='chs') 69 | decoder = wenet.load_model(language="chinese") 70 | 71 | names = [] 72 | sentences = [] 73 | for audio_path in tqdm.tqdm(glob.glob(audio_root + "/*")): 74 | name = os.path.basename(audio_path)[:-4] 75 | # sentence = decoder.decode_wav(audio_path) 76 | # sentence = sentence.split('"')[5] 77 | sentence = decoder.transcribe(audio_path) 78 | sentence = sentence["text"] 79 | # if len(sentence) > 0: sentence = text_punc(text=sentence) 80 | names.append(name) 81 | sentences.append(sentence) 82 | 83 | ## write to csv file 84 | columns = ["name", "sentence"] 85 | data = np.column_stack([names, sentences]) 86 | df = pd.DataFrame(data=data, columns=columns) 87 | df[columns] = df[columns].astype(str) 88 | df.to_csv(save_path, index=False) 89 | 90 | 91 | # python main-asr.py refinement_transcription_files_asr(old_path, new_path) 92 | def refinement_transcription_files_asr(old_path, new_path): 93 | from paddlespeech.cli.text.infer import TextExecutor 94 | 95 | text_punc = TextExecutor() 96 | 97 | ## read 98 | names, sentences = [], [] 99 | 
df_label = pd.read_csv(old_path) 100 | for _, row in df_label.iterrows(): ## read for each row 101 | names.append(row["name"]) 102 | sentence = row["sentence"] 103 | if pd.isna(sentence): 104 | sentences.append("") 105 | else: 106 | sentence = text_punc(text=sentence) 107 | sentences.append(sentence) 108 | print(sentences[-1]) 109 | 110 | ## write 111 | columns = ["name", "chinese"] 112 | data = np.column_stack([names, sentences]) 113 | df = pd.DataFrame(data=data, columns=columns) 114 | df[columns] = df[columns].astype(str) 115 | df.to_csv(new_path, index=False) 116 | 117 | 118 | # python main-asr.py merge_trans_with_checked dataset/mer2024-dataset-process/transcription.csv dataset/mer2024-dataset/label-transcription.csv dataset/mer2024-dataset-process/transcription-merge.csv 119 | def merge_trans_with_checked(new_path, check_path, merge_path): 120 | 121 | # read new_path 7369 122 | name2new = {} 123 | names = func_read_key_from_csv(new_path, "name") 124 | trans = func_read_key_from_csv(new_path, "sentence") 125 | for name, tran in zip(names, trans): 126 | name2new[name] = tran 127 | print(f"new sample: {len(name2new)}") 128 | 129 | # read check_path 5030 130 | name2check = {} 131 | names = func_read_key_from_csv(check_path, "name") 132 | trans = func_read_key_from_csv(check_path, "chinese") 133 | for name, tran in zip(names, trans): 134 | name2check[name] = tran 135 | print(f"check sample: {len(name2check)}") 136 | 137 | # 生成新的merge结果 138 | name2merge = {} 139 | for name in name2new: 140 | if name in name2check: 141 | name2merge[name] = [name2check[name]] 142 | else: 143 | name2merge[name] = [name2new[name]] 144 | print(f"merge sample: {len(name2merge)}") 145 | 146 | # 存储 name2merge 147 | names = [name for name in name2merge] 148 | keynames = ["chinese"] 149 | func_write_key_to_csv(merge_path, names, name2merge, keynames) 150 | 151 | 152 | if __name__ == "__main__": 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument( 155 | "--dataset", 156 | type=str, 157 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/", 158 | help="file name", 159 | ) 160 | args = parser.parse_args() 161 | 162 | dataset = args.dataset 163 | 164 | audio_root = os.path.join(dataset, "audio") 165 | save_root = os.path.join(dataset, "text") 166 | 167 | if not os.path.exists(save_root): 168 | os.makedirs(save_root) 169 | 170 | generate_transcription_files_asr( 171 | audio_root, 172 | os.path.join(save_root, "transcription-old.csv"), 173 | ) 174 | 175 | refinement_transcription_files_asr( 176 | os.path.join(save_root, "transcription-old.csv"), 177 | os.path.join(save_root, "transcription.csv"), 178 | ) 179 | -------------------------------------------------------------------------------- /toolkit/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import numpy as np 5 | 6 | 7 | # classification loss 8 | class CELoss(nn.Module): 9 | 10 | def __init__(self): 11 | super(CELoss, self).__init__() 12 | self.loss = nn.NLLLoss(reduction="sum") 13 | 14 | def forward(self, pred, target): 15 | pred = F.log_softmax(pred, 1) # [n_samples, n_classes] 16 | target = target.long() # [n_samples] 17 | loss = self.loss(pred, target) / len(pred) 18 | return loss 19 | 20 | 21 | # regression loss 22 | class MSELoss(nn.Module): 23 | 24 | def __init__(self): 25 | super(MSELoss, self).__init__() 26 | self.loss = nn.MSELoss(reduction="sum") 27 | 28 | def forward(self, pred, 
target): 29 | pred = pred.view(-1, 1) 30 | target = target.view(-1, 1) 31 | loss = self.loss(pred, target) / len(pred) 32 | return loss 33 | 34 | 35 | class CenterLoss(nn.Module): 36 | def __init__(self, num_classes, feat_dim, lambda_c=0.5): 37 | super(CenterLoss, self).__init__() 38 | self.num_classes = num_classes # number of classes 39 | self.feat_dim = feat_dim # feature dimension 40 | self.lambda_c = lambda_c # balancing coefficient 41 | # initialize a feature center for each class 42 | # self.centers = nn.Parameter(torch.randn(num_classes, feat_dim)) 43 | self.centers = nn.Parameter(torch.FloatTensor(num_classes, feat_dim)) 44 | nn.init.xavier_uniform_(self.centers) 45 | 46 | def forward(self, x, labels): 47 | """ 48 | x: feature vectors of the current batch (batch_size, feat_dim) 49 | labels: class labels of the current batch (batch_size) 50 | """ 51 | batch_size = x.size(0) 52 | 53 | # gather the centers of the classes present in the current batch 54 | centers_batch = self.centers.cuda().index_select(0, labels) 55 | 56 | # compute the center loss 57 | center_loss = ( 58 | self.lambda_c * 0.5 * torch.sum((x - centers_batch) ** 2) / batch_size 59 | ) 60 | return center_loss 61 | 62 | 63 | class MultiClassFocalLossWithAlpha(nn.Module): 64 | def __init__(self, alpha=[0.2, 0.3, 0.5], gamma=2, reduction="mean", classnum=None): 65 | """ 66 | :param alpha: list of class weights, e.g. for 3 classes: 0.2 for class 0, 0.3 for class 1, 0.5 for class 2 67 | :param gamma: focusing parameter for hard-example mining 68 | :param reduction: 69 | """ 70 | super(MultiClassFocalLossWithAlpha, self).__init__() 71 | if classnum is None: 72 | self.alpha = torch.tensor(alpha) 73 | else: 74 | self.alpha = torch.tensor([1.0 / classnum] * classnum) 75 | self.gamma = gamma 76 | self.reduction = reduction 77 | 78 | def forward(self, pred, target): 79 | alpha = self.alpha.cuda()[ 80 | target 81 | ] # assign a class weight to every sample in the batch, shape=(bs), 1-D vector 82 | log_softmax = torch.log_softmax( 83 | pred, dim=1 84 | ) # apply softmax to the raw model outputs and take the log, shape=(bs, 3) 85 | logpt = torch.gather( 86 | log_softmax, dim=1, index=target.view(-1, 1) 87 | ) # pick each sample's log_softmax value at its label position, shape=(bs, 1) 88 | logpt = logpt.view(-1) # flatten, shape=(bs) 89 | ce_loss = -logpt # negating the log_softmax gives the cross entropy 90 | pt = torch.exp( 91 | logpt 92 | ) # exp of log_softmax removes the log, i.e. each sample's softmax value at its label position, shape=(bs) 93 | focal_loss = ( 94 | alpha * (1 - pt) ** self.gamma * ce_loss 95 | ) # compute the focal loss per the formula, one loss value per sample, shape=(bs) 96 | if self.reduction == "mean": 97 | return torch.mean(focal_loss) 98 | if self.reduction == "sum": 99 | return torch.sum(focal_loss) 100 | return focal_loss 101 | 102 | 103 | class LDAMLoss(nn.Module): 104 | def __init__(self, cls_num_list, device, max_m=0.5, weight=None, s=30): 105 | super(LDAMLoss, self).__init__() 106 | print("LDAM weights:", weight) 107 | m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) 108 | m_list = m_list * (max_m / np.max(m_list)) 109 | m_list = torch.cuda.FloatTensor(m_list) 110 | self.m_list = m_list 111 | self.device = device 112 | assert s > 0 113 | self.s = s 114 | self.weight = weight 115 | 116 | def forward(self, x, target): 117 | index = torch.zeros_like(x, dtype=torch.uint8) 118 | index.scatter_(1, target.data.view(-1, 1), 1) 119 | 120 | index_float = index.type(torch.cuda.FloatTensor) 121 | batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1)) 122 | batch_m = batch_m.view((-1, 1)) 123 | x_m = x - batch_m 124 | 125 | output = torch.where(index, x_m, x) 126 | return F.cross_entropy(self.s * output, target, weight=self.weight) 127 | # criterion = LDAMLoss(cls_num_list=a list of number of samples in each class, max_m=0.5, s=30, weight=per_cls_weights) 128 | """ 129 | max_m: represents the margin used in the loss function.
It controls the separation between different classes in the feature space. 130 | The appropriate value for max_m depends on the specific dataset and the severity of the class imbalance. 131 | You can start with a small value and gradually increase it to observe the impact on the model's performance. 132 | If the model struggles with class separation or experiences underfitting, increasing max_m might help. However, 133 | be cautious not to set it too high, as it can cause overfitting or make the model too conservative. 134 | 135 | s: higher s value results in sharper probabilities (more confident predictions), while a lower s value 136 | leads to more balanced probabilities. In practice, a larger s value can help stabilize training and improve convergence, 137 | especially when dealing with difficult optimization landscapes. 138 | The choice of s depends on the desired scale of the logits and the specific requirements of your problem. 139 | It can be used to adjust the balance between the margin and the original logits. A larger s value amplifies 140 | the impact of the logits and can be useful when dealing with highly imbalanced datasets. 141 | You can experiment with different values of s to find the one that works best for your dataset and model. 142 | 143 | """ 144 | 145 | 146 | class LMFLoss(nn.Module): 147 | def __init__( 148 | self, 149 | cls_num_list, 150 | device, 151 | weight=None, 152 | alpha=0.2, 153 | beta=0.2, 154 | gamma=2, 155 | max_m=0.8, 156 | s=5, 157 | add_LDAM_weigth=False, 158 | ): 159 | super().__init__() 160 | self.focal_loss = MultiClassFocalLossWithAlpha(classnum=len(cls_num_list)) 161 | if add_LDAM_weigth: 162 | LDAM_weight = weight 163 | else: 164 | LDAM_weight = None 165 | print( 166 | "LMF loss: alpha: ", 167 | alpha, 168 | " beta: ", 169 | beta, 170 | " gamma: ", 171 | gamma, 172 | " max_m: ", 173 | max_m, 174 | " s: ", 175 | s, 176 | " LDAM_weight: ", 177 | add_LDAM_weigth, 178 | ) 179 | self.ldam_loss = LDAMLoss(cls_num_list, device, max_m, weight=LDAM_weight, s=s) 180 | self.alpha = alpha 181 | self.beta = beta 182 | 183 | def forward(self, output, target): 184 | focal_loss_output = self.focal_loss(output, target) 185 | ldam_loss_output = self.ldam_loss(output, target) 186 | total_loss = self.alpha * focal_loss_output + self.beta * ldam_loss_output 187 | return total_loss 188 | -------------------------------------------------------------------------------- /feature_extraction/audio/extract_audio_huggingface.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import math 4 | import os 5 | 6 | # import config 7 | import sys 8 | import time 9 | 10 | import numpy as np 11 | import soundfile as sf 12 | import torch 13 | 14 | current_file_path = os.path.abspath(__file__) 15 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path))) 16 | from transformers import AutoModel, Wav2Vec2FeatureExtractor, WhisperFeatureExtractor 17 | 18 | import config 19 | 20 | # supported models 21 | ################## ENGLISH ###################### 22 | WAV2VEC2_BASE = ( 23 | "wav2vec2-base-960h" # https://huggingface.co/facebook/wav2vec2-base-960h 24 | ) 25 | WAV2VEC2_LARGE = ( 26 | "wav2vec2-large-960h" # https://huggingface.co/facebook/wav2vec2-large-960h 27 | ) 28 | DATA2VEC_AUDIO_BASE = "data2vec-audio-base-960h" # https://huggingface.co/facebook/data2vec-audio-base-960h 29 | DATA2VEC_AUDIO_LARGE = ( 30 | "data2vec-audio-large" # https://huggingface.co/facebook/data2vec-audio-large 31 | ) 32 | 
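# NOTE: each model_name constant above/below is expected to match a folder name under
# config.PATH_TO_PRETRAINED_MODELS/transformers/, because extract() loads the checkpoint from
# os.path.join(config.PATH_TO_PRETRAINED_MODELS, f"transformers/{model_name}").
# A typical invocation (paths are placeholders; the wavs under <dataset>/audio must already be 16 kHz) might be:
#   python extract_audio_huggingface.py --model_name chinese-hubert-large --feature_level UTTERANCE --dataset /path/to/dataset-process/ --gpu 0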
33 | ################## CHINESE ###################### 34 | HUBERT_BASE_CHINESE = ( 35 | "chinese-hubert-base" # https://huggingface.co/TencentGameMate/chinese-hubert-base 36 | ) 37 | HUBERT_LARGE_CHINESE = "chinese-hubert-large" # https://huggingface.co/TencentGameMate/chinese-hubert-large 38 | WAV2VEC2_BASE_CHINESE = "chinese-wav2vec2-base" # https://huggingface.co/TencentGameMate/chinese-wav2vec2-base 39 | WAV2VEC2_LARGE_CHINESE = "chinese-wav2vec2-large" # https://huggingface.co/TencentGameMate/chinese-wav2vec2-large 40 | 41 | ################## Multilingual ################# 42 | WAVLM_BASE = "wavlm-base" # https://huggingface.co/microsoft/wavlm-base 43 | WAVLM_LARGE = "wavlm-large" # https://huggingface.co/microsoft/wavlm-large 44 | WHISPER_BASE = "whisper-base" # https://huggingface.co/openai/whisper-base 45 | WHISPER_LARGE = "whisper-large-v2" # https://huggingface.co/openai/whisper-large-v2 46 | 47 | 48 | ## Target: avoid too long inputs 49 | # input_values: [1, wavlen], output: [bsize, maxlen] 50 | def split_into_batch(input_values, maxlen=16000 * 10): 51 | if len(input_values[0]) <= maxlen: 52 | return input_values 53 | 54 | bs, wavlen = input_values.shape 55 | assert bs == 1 56 | tgtlen = math.ceil(wavlen / maxlen) * maxlen 57 | batches = torch.zeros((1, tgtlen)) 58 | batches[:, :wavlen] = input_values 59 | batches = batches.view(-1, maxlen) 60 | return batches 61 | 62 | 63 | def extract(model_name, audio_files, save_dir, feature_level, gpu): 64 | 65 | start_time = time.time() 66 | 67 | # load model 68 | model_file = os.path.join( 69 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{model_name}" 70 | ) 71 | 72 | if model_name in [WHISPER_BASE, WHISPER_LARGE]: 73 | model = AutoModel.from_pretrained(model_file) 74 | feature_extractor = WhisperFeatureExtractor.from_pretrained(model_file) 75 | else: 76 | model = AutoModel.from_pretrained(model_file) 77 | feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_file) 78 | 79 | load_time = time.time() 80 | 81 | if gpu != -1: 82 | device = torch.device(f"cuda:{gpu}") 83 | model.to(device) 84 | model.eval() 85 | 86 | # iterate audios 87 | t1 = time.time() 88 | for idx, audio_file in enumerate(audio_files, 1): 89 | file_name = os.path.basename(audio_file) 90 | vid = file_name[:-4] 91 | print(f'Processing "{file_name}" ({idx}/{len(audio_files)})...') 92 | 93 | ## process for too short ones 94 | samples, sr = sf.read(audio_file) 95 | assert sr == 16000, "currently, we only test on 16k audio" 96 | 97 | ## model inference 98 | with torch.no_grad(): 99 | if model_name in [WHISPER_BASE, WHISPER_LARGE]: 100 | layer_ids = [-1] 101 | input_features = feature_extractor( 102 | samples, sampling_rate=sr, return_tensors="pt" 103 | ).input_features # [1, 80, 3000] 104 | decoder_input_ids = ( 105 | torch.tensor([[1, 1]]) * model.config.decoder_start_token_id 106 | ) 107 | if gpu != -1: 108 | input_features = input_features.to(device) 109 | if gpu != -1: 110 | decoder_input_ids = decoder_input_ids.to(device) 111 | last_hidden_state = model( 112 | input_features, decoder_input_ids=decoder_input_ids 113 | ).last_hidden_state 114 | assert last_hidden_state.shape[0] == 1 115 | feature = ( 116 | last_hidden_state[0].detach().squeeze().cpu().numpy() 117 | ) # (2, D) 118 | else: 119 | layer_ids = [-4, -3, -2, -1] 120 | input_values = feature_extractor( 121 | samples, sampling_rate=sr, return_tensors="pt" 122 | ).input_values # [1, wavlen] 123 | input_values = split_into_batch( 124 | input_values 125 | ) # [bsize, maxlen=10*16000] 126 | if gpu != 
-1: 127 | input_values = input_values.to(device) 128 | hidden_states = model( 129 | input_values, output_hidden_states=True 130 | ).hidden_states # tuple of (B, T, D) 131 | feature = torch.stack(hidden_states)[layer_ids].sum( 132 | dim=0 133 | ) # (B, T, D) # -> compress waveform channel 134 | bsize, segnum, featdim = feature.shape 135 | feature = ( 136 | feature.view(-1, featdim).detach().squeeze().cpu().numpy() 137 | ) # (B*T, D) 138 | 139 | ## store values 140 | csv_file = os.path.join(save_dir, f"{vid}.npy") 141 | if feature_level == "UTTERANCE": 142 | feature = np.array(feature).squeeze() 143 | if len(feature.shape) != 1: 144 | feature = np.mean(feature, axis=0) 145 | np.save(csv_file, feature) 146 | else: 147 | np.save(csv_file, feature) 148 | 149 | t2 = time.time() 150 | end_time = time.time() 151 | print(f"Model load time used: {load_time - start_time:.1f}s.") 152 | print(f"Total time used: {end_time - start_time:.1f}s.") 153 | print(f"average time used: {(t2 - t1)/len(audio_files):.1f}s.") 154 | 155 | 156 | if __name__ == "__main__": 157 | 158 | parser = argparse.ArgumentParser(description="Run.") 159 | parser.add_argument("--gpu", type=int, default=0, help="index of gpu") 160 | parser.add_argument( 161 | "--model_name", 162 | type=str, 163 | default="chinese-hubert-large", 164 | help="feature extractor", 165 | ) 166 | parser.add_argument( 167 | "--feature_level", type=str, default="FRAME", help="FRAME or UTTERANCE" 168 | ) 169 | parser.add_argument( 170 | "--dataset", 171 | type=str, 172 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/", 173 | help="input dataset", 174 | ) 175 | # ------ temporary: test the effect of SNR on the results ------ 176 | parser.add_argument( 177 | "--noise_case", 178 | type=str, 179 | default=None, 180 | help="extract feature of different noise conditions", 181 | ) 182 | # ------ temporary: test the effect of tts audio on the results ------- 183 | parser.add_argument( 184 | "--tts_lang", 185 | type=str, 186 | default=None, 187 | help="extract feature from tts audio, [chinese, english]", 188 | ) 189 | args = parser.parse_args() 190 | 191 | # analyze input 192 | audio_dir = os.path.join(args.dataset, "audio") 193 | save_dir = os.path.join(args.dataset, "features") 194 | 195 | if args.noise_case is not None: 196 | audio_dir += "_" + args.noise_case 197 | if args.tts_lang is not None: 198 | audio_dir += "-" + f"tts{args.tts_lang[:3]}16k" 199 | 200 | # audio_files 201 | audio_files = glob.glob(os.path.join(audio_dir, "*.wav")) 202 | print(f'Find total "{len(audio_files)}" audio files.') 203 | 204 | # save_dir 205 | if args.noise_case is not None: 206 | dir_name = f"{args.model_name}-noise{args.noise_case}-{args.feature_level[:3]}" 207 | elif args.tts_lang is not None: 208 | dir_name = f"{args.model_name}-tts{args.tts_lang[:3]}-{args.feature_level[:3]}" 209 | else: 210 | dir_name = f"{args.model_name}-{args.feature_level[:3]}" 211 | 212 | save_dir = os.path.join(save_dir, dir_name) 213 | if not os.path.exists(save_dir): 214 | os.makedirs(save_dir) 215 | 216 | # extract features 217 | extract(args.model_name, audio_files, save_dir, args.feature_level, gpu=args.gpu) 218 | -------------------------------------------------------------------------------- /feature_extraction/visual/extract_openface.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import shutil 5 | import pathlib 6 | import argparse 7 | import numpy as np 8 | from util import read_hog, read_csv 9 | 10 | import sys 11 | 12 | current_file_path =
os.path.abspath(__file__) 13 | sys.path.append(os.path.dirname(os.path.dirname((current_file_path)))) 14 | 15 | import config 16 | 17 | 18 | def generate_face_faceDir(input_root, save_root, savetype="image"): 19 | if savetype == "image": 20 | for dir_path in sorted( 21 | glob.glob(input_root + "/*_aligned") 22 | ): # 'xx/xx/000100_guest_aligned' 23 | frame_names = os.listdir(dir_path) # ['xxx.bmp'] 24 | if len(frame_names) != 1: 25 | continue 26 | frame_path = os.path.join( 27 | dir_path, frame_names[0] 28 | ) # 'xx/xx/000100_guest_aligned/xxx.bmp' 29 | name = os.path.basename(dir_path)[: -len("_aligned")] # '000100_guest' 30 | save_path = os.path.join(save_root, name + ".bmp") 31 | shutil.copy(frame_path, save_path) 32 | elif savetype == "npy": 33 | frames = [] 34 | for dir_path in sorted( 35 | glob.glob(input_root + "/*_aligned") 36 | ): # 'xx/xx/000100_guest_aligned' 37 | frame_names = os.listdir(dir_path) # ['xxx.bmp'] 38 | if len(frame_names) != 1: 39 | continue 40 | frame_path = os.path.join( 41 | dir_path, frame_names[0] 42 | ) # 'xx/xx/000100_guest_aligned/xxx.bmp' 43 | frame = cv2.imread(frame_path) 44 | frames.append(frame) 45 | videoname = os.path.basename(input_root) 46 | save_path = os.path.join(save_root, videoname + ".npy") 47 | np.save(save_path, frames) 48 | 49 | 50 | def generate_face_videoOne(input_root, save_root, savetype="image"): 51 | for dir_path in glob.glob( 52 | input_root + "/*_aligned" 53 | ): # 'xx/xx/000100_guest_aligned' 54 | frame_names = sorted(os.listdir(dir_path)) # ['xxx.bmp'] 55 | if savetype == "image": 56 | for ii in range(len(frame_names)): 57 | frame_path = os.path.join( 58 | dir_path, frame_names[ii] 59 | ) # 'xx/xx/000100_guest_aligned/xxx.bmp' 60 | frame_name = os.path.basename(frame_path) 61 | save_path = os.path.join(save_root, frame_name) 62 | shutil.copy(frame_path, save_path) 63 | elif savetype == "npy": 64 | frames = [] 65 | for ii in range(len(frame_names)): 66 | frame_path = os.path.join(dir_path, frame_names[ii]) 67 | frame = cv2.imread(frame_path) 68 | frames.append(frame) 69 | videoname = os.path.basename(input_root) 70 | save_path = os.path.join(save_root, videoname + ".npy") 71 | np.save(save_path, frames) 72 | 73 | 74 | def generate_hog(input_root, save_root): 75 | for hog_path in glob.glob(input_root + "/*.hog"): 76 | csv_path = hog_path[:-4] + ".csv" 77 | if os.path.exists(csv_path): 78 | hog_name = os.path.basename(hog_path)[:-4] 79 | _, feature = read_hog(hog_path) 80 | save_path = os.path.join(save_root, hog_name + ".npy") 81 | np.save(save_path, feature) 82 | 83 | 84 | def generate_csv(input_root, save_root, startIdx): 85 | for csv_path in glob.glob(input_root + "/*.csv"): 86 | csv_name = os.path.basename(csv_path)[:-4] 87 | feature = read_csv(csv_path, startIdx) 88 | save_path = os.path.join(save_root, csv_name + ".npy") 89 | np.save(save_path, feature) 90 | 91 | 92 | # name_npy: only process on names in 'name_npy' 93 | def extract( 94 | input_dir, process_type, save_dir, face_dir, hog_dir, pose_dir, name_npy=None 95 | ): 96 | 97 | # => process_names 98 | if name_npy is not None: # only process the specified files 99 | process_names = np.load(name_npy) 100 | else: # process all video files 101 | vids = os.listdir(input_dir) 102 | process_names = [vid.rsplit(".", 1)[0] for vid in vids] 103 | print(f"processing names: {len(process_names)}") 104 | 105 | # process folders 106 | vids = os.listdir(input_dir) 107 | print(f'Find total "{len(vids)}" videos.') 108 | for i, vid in enumerate(vids, 1): 109 | saveVid = vid.rsplit(".", 1)[0] # unify folder and video names
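# Only videos whose stem appears in process_names are handled below: --name_npy restricts the run
# to a given list, otherwise every file in input_dir is processed, and videos whose openface_face
# output folder already exists are skipped.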
110 | if saveVid not in process_names: 111 | continue 112 | 113 | print(f"Processing video '{vid}' ({i}/{len(vids)})...") 114 | input_root = os.path.join(input_dir, vid) # exists 115 | save_root = os.path.join(save_dir, saveVid) 116 | face_root = os.path.join(face_dir, saveVid) 117 | hog_root = os.path.join(hog_dir, saveVid) 118 | pose_root = os.path.join(pose_dir, saveVid) 119 | if os.path.exists(face_root): 120 | continue 121 | if not os.path.exists(save_root): 122 | os.makedirs(save_root) 123 | if not os.path.exists(face_root): 124 | os.makedirs(face_root) 125 | if not os.path.exists(hog_root): 126 | os.makedirs(hog_root) 127 | if not os.path.exists(pose_root): 128 | os.makedirs(pose_root) 129 | if process_type == "faceDir": 130 | exe_path = os.path.join(config.PATH_TO_OPENFACE_Win, "FaceLandmarkImg.exe") 131 | commond = '%s -fdir "%s" -out_dir "%s"' % (exe_path, input_root, save_root) 132 | os.system(commond) 133 | # generate_face_faceDir(save_root, face_root, savetype='image') # more subtle files 134 | generate_face_faceDir( 135 | save_root, face_root, savetype="npy" 136 | ) # compress frames into npy 137 | # generate_hog(save_root, hog_root) # not used 138 | # generate_csv(save_root, pose_root, startIdx=2) # not used 139 | ## delete temp folder 140 | dir_path = pathlib.Path(save_root) 141 | shutil.rmtree(dir_path) 142 | elif process_type == "videoOne": 143 | exe_path = os.path.join( 144 | config.PATH_TO_OPENFACE_Win, "build/bin/FeatureExtraction" 145 | ) # jxie 146 | commond = '%s -f "%s" -out_dir "%s"' % (exe_path, input_root, save_root) 147 | os.system(commond) 148 | # generate_face_videoOne(save_root, face_root, savetype='image') # more subtle files 149 | generate_face_videoOne( 150 | save_root, face_root, savetype="npy" 151 | ) # compress frames into npy 152 | # generate_hog(save_root, hog_root) 153 | # generate_csv(save_root, pose_root, startIdx=5) 154 | ## delete temp folder 155 | dir_path = pathlib.Path(save_root) 156 | shutil.rmtree(dir_path) 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = argparse.ArgumentParser(description="Run.") 161 | parser.add_argument( 162 | "--overwrite", 163 | action="store_true", 164 | default=True, 165 | help="whether overwrite existed feature folder.", 166 | ) 167 | parser.add_argument( 168 | "--dataset", 169 | type=str, 170 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/", 171 | help="input dataset dir path", 172 | ) 173 | parser.add_argument("--name_npy", type=str, default=None, help="process name lists") 174 | parser.add_argument( 175 | "--type", 176 | type=str, 177 | default="videoOne", 178 | choices=["faceDir", "videoOne"], 179 | help="faceDir: process on facedirs; videoOne: process on one video", 180 | ) 181 | params = parser.parse_args() 182 | 183 | print(f"==> Extracting openface features...") 184 | 185 | # in: face dir 186 | dataset = params.dataset 187 | process_type = params.type 188 | input_dir = os.path.join(dataset, "video") 189 | 190 | # out: feature csv dir 191 | save_dir = os.path.join(dataset, "features", "openface_all") 192 | hog_dir = os.path.join(dataset, "features", "openface_hog") 193 | pose_dir = os.path.join(dataset, "features", "openface_pose") 194 | face_dir = os.path.join(dataset, "features", "openface_face") 195 | 196 | if not os.path.exists(save_dir): 197 | os.makedirs(save_dir) 198 | if not os.path.exists(hog_dir): 199 | os.makedirs(hog_dir) 200 | if not os.path.exists(pose_dir): 201 | os.makedirs(pose_dir) 202 | if not os.path.exists(face_dir): 203 | 
os.makedirs(face_dir) 204 | 205 | # process 206 | extract( 207 | input_dir, process_type, save_dir, face_dir, hog_dir, pose_dir, params.name_npy 208 | ) 209 | 210 | print(f"==> Finish") 211 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import os 5 | import os.path as osp 6 | import random 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | import torch.optim as optim 12 | from omegaconf import OmegaConf 13 | from sklearn.metrics import ( 14 | accuracy_score, 15 | confusion_matrix, 16 | f1_score, 17 | precision_score, 18 | recall_score, 19 | ) 20 | from tqdm import tqdm 21 | 22 | import config 23 | from dataloader.dataloader import get_dataloader_train, get_dataloader_valid 24 | from models import get_models 25 | from toolkit.utils.draw_process import draw_loss, draw_metric 26 | from toolkit.utils.eval import calculate_results 27 | from toolkit.utils.functions import func_update_storage, merge_args_config 28 | from toolkit.utils.loss import CELoss, MSELoss 29 | from toolkit.utils.metric import gain_metric_from_results 30 | 31 | # emotion rule 32 | # emotions = ["neutral", "angry", "happy", "sad", "worried", "surprise"] 33 | 34 | 35 | def set_seed(seed=0): 36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | np.random.seed(seed) 40 | random.seed(seed) 41 | torch.backends.cudnn.benchmark = False 42 | torch.backends.cudnn.deterministic = True 43 | 44 | 45 | def test_model( 46 | args, 47 | model, 48 | dataloader, 49 | ): 50 | 51 | vidnames = [] 52 | val_preds, val_labels = [], [] 53 | emo_probs, emo_labels = [], [] 54 | losses = [] 55 | 56 | model.eval() 57 | 58 | pbar = tqdm(range(len(dataloader)), desc=f"test") 59 | 60 | for iter, data in enumerate(dataloader): 61 | 62 | # read data + cuda 63 | audios, texts, videos, emos, vals, bnames = data 64 | 65 | vidnames += bnames 66 | 67 | batch = {} 68 | batch["videos"] = videos.float().cuda() 69 | batch["audios"] = audios.float().cuda() 70 | batch["texts"] = texts.float().cuda() 71 | 72 | emos = emos.long().cuda() 73 | vals = vals.float().cuda() 74 | 75 | if args.train_input_mode == "input_gt": 76 | _, emos_out, _, _ = model([batch, emos]) 77 | elif args.train_input_mode == "input": 78 | _, emos_out, _, _ = model(batch) 79 | 80 | emo_probs.append(emos_out.data.cpu().numpy()) 81 | emo_labels.append(emos.data.cpu().numpy()) 82 | 83 | pbar.update(1) 84 | 85 | pbar.close() 86 | 87 | if emo_probs != []: 88 | emo_probs = np.concatenate(emo_probs) 89 | if emo_labels != []: 90 | emo_labels = np.concatenate(emo_labels) 91 | if val_preds != []: 92 | val_preds = np.concatenate(val_preds) 93 | if val_labels != []: 94 | val_labels = np.concatenate(val_labels) 95 | results, _ = calculate_results(emo_probs, emo_labels, val_preds, val_labels) 96 | save_results = dict( 97 | names=vidnames, 98 | **results, 99 | ) 100 | 101 | y_true = [] 102 | y_pred = [] 103 | emo_preds = np.argmax(emo_probs, 1) 104 | for emo_pred, emo_label in zip(emo_preds, emo_labels): 105 | y_pred.append(emotions[emo_pred]) 106 | y_true.append(emotions[emo_label]) 107 | save_results["emotions"] = emotions 108 | conf_matrix = confusion_matrix(y_true, y_pred, labels=emotions) 109 | accuracy = accuracy_score(y_true, y_pred) 110 | precision = precision_score(y_true, y_pred, average="weighted") 111 | recall = 
recall_score(y_true, y_pred, average="weighted") 112 | f1 = f1_score(y_true, y_pred, average="weighted") 113 | print("emotions: " + str(emotions)) 114 | print("Confusion Matrix:") 115 | print(str(conf_matrix)) 116 | print(f"Accuracy: {accuracy:.4f}") 117 | print(f"Precision: {precision:.4f}") 118 | print(f"Recall: {recall:.4f}") 119 | print(f"F1 Score: {f1:.4f}") 120 | return save_results 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument( 126 | "--seed", 127 | type=int, 128 | default=0, 129 | help="random seed", 130 | ) 131 | 132 | # --- data config 133 | parser.add_argument( 134 | "--dataset", 135 | type=str, 136 | default="MER24-test_3A_whisper-base-UTT", 137 | help="dataset info name", 138 | ) 139 | parser.add_argument( 140 | "--emo_rule", 141 | type=str, 142 | default="MER", 143 | help="emo map function from emotion to index", 144 | ) 145 | 146 | # --- save config 147 | parser.add_argument( 148 | "--save_root", 149 | type=str, 150 | default="/mnt/public/gxj/EmoNets/saved_test", 151 | help="save prediction results and models", 152 | ) 153 | parser.add_argument( 154 | "--save_name", 155 | type=str, 156 | default="test", 157 | help="save prediction results and models", 158 | ) 159 | 160 | # --- model config 161 | parser.add_argument( 162 | "--model", 163 | type=str, 164 | default="attention", 165 | help="model name for training [attention, mer_rank5, and others]", 166 | ) 167 | parser.add_argument( 168 | "--load_key", 169 | type=str, 170 | default="/mnt/public/gxj/EmoNets/saved/seed0_MER24_3A_whisper-base-UTT/attention/2025-04-06-21-26-53", 171 | help="keyword about which model weight to load", 172 | ) 173 | 174 | # --- test sets 175 | parser.add_argument( 176 | "--batch_size", 177 | type=int, 178 | default=512, 179 | metavar="BS", 180 | help="batch size", 181 | ) 182 | parser.add_argument( 183 | "--num_workers", type=int, default=4, metavar="nw", help="number of workers" 184 | ) 185 | parser.add_argument("--gpu", default=0, type=int, help="GPU id to use") 186 | 187 | args = parser.parse_args() 188 | 189 | set_seed(args.seed) 190 | 191 | torch.cuda.set_device(args.gpu) 192 | 193 | # if load_key is not a registered keyword, use it directly as the checkpoint path 194 | if args.load_key not in config.MODEL_DIR_DICT[args.model].keys(): 195 | config.MODEL_DIR_DICT[args.model][args.load_key] = args.load_key 196 | 197 | # override the settings saved at training time with the current command-line arguments 198 | load_args_path = osp.join( 199 | config.MODEL_DIR_DICT[args.model][args.load_key], "best_args.npy" 200 | ) 201 | load_args = np.load(load_args_path, allow_pickle=True).item()["args"] 202 | load_args_dic = vars(load_args) 203 | args_dic = vars(args) 204 | for key in args_dic: 205 | load_args_dic[key] = args_dic[key] 206 | args = argparse.Namespace(**load_args_dic) # merge the two sets of arguments 207 | 208 | print("====== set save dir =======") 209 | now = datetime.datetime.now() 210 | save_dir = os.path.join(args.save_root, args.dataset, args.model, args.save_name) 211 | 212 | if not os.path.exists(save_dir): 213 | os.makedirs(save_dir) 214 | print("{}\n".format(save_dir)) 215 | 216 | logging_path = osp.join(save_dir, "info.log") 217 | logging.basicConfig( 218 | level=logging.INFO, 219 | format="%(levelname)s: %(message)s", 220 | filename=logging_path, 221 | filemode="w", 222 | ) 223 | 224 | logging.info("====== load info and config =======") 225 | dataset_info = np.load( 226 | osp.join(config.PATH_TO_TEST_LST, args.dataset + ".npy"), allow_pickle=True 227 | ).item() 228 | 229 | for feat_name in ["video", "audio", "text"]: 230 | logging.info( 231 | "Input feature: {} ===>
dim is (1, {})".format( 232 | dataset_info["feat_dim"][f"{feat_name}_feat_description"], 233 | dataset_info["feat_dim"][f"{feat_name}_dim"], 234 | ) 235 | ) 236 | 237 | emo_rule = config.EMO_RULE[args.emo_rule] 238 | emotions = list(emo_rule) 239 | 240 | logging.info("====== load dataset =======") 241 | test_info = dataset_info["test_info"] 242 | test_loader = get_dataloader_valid( 243 | names=test_info["names"], 244 | emo_labels=test_info["emos"], 245 | val_labels=test_info["vals"], 246 | emo_rule=emo_rule, 247 | audio_feat_paths=test_info["audio_feat"], 248 | video_feat_paths=test_info["video_feat"], 249 | text_feat_paths=test_info["text_feat"], 250 | batch_size=args.batch_size, 251 | num_workers=args.num_workers, 252 | shuffle=False, 253 | ) 254 | 255 | logging.info("====== build model =======") 256 | model = get_models(args).cuda() 257 | assert args.load_key is not None 258 | model_path = osp.join(config.MODEL_DIR_DICT[args.model][args.load_key], "best.pth") 259 | model.load_state_dict( 260 | torch.load(model_path, map_location=f"cuda:0", weights_only=True) 261 | ) 262 | 263 | logging.info("load model: {} \n".format(args.model)) 264 | logging.info("load model weight: {} \n".format(model_path)) 265 | 266 | logging.info("====== Evaluation =======") 267 | best_eval_metric = None # for select best model weight 268 | record_test = {"emoacc": [], "emofscore": []} 269 | 270 | test_results = test_model( 271 | args, 272 | model, 273 | test_loader, 274 | ) 275 | 276 | # np.save(osp.join(save_dir, "results.npy"), test_results) 277 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /feature_extraction/visual/extract_vision_huggingface.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import cv2 4 | import math 5 | import argparse 6 | import numpy as np 7 | from PIL import Image 8 | 9 | import torch 10 | import timm # pip install timm==0.9.7 11 | from transformers import AutoModel, AutoFeatureExtractor, AutoImageProcessor 12 | import time 13 | 14 | # import config 15 | import sys 16 | 17 | current_file_path = os.path.abspath(__file__) 18 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path))) 19 | 20 | import config 21 | 22 | ##################### Pretrained models ##################### 23 | CLIP_VIT_BASE = ( 24 | "clip-vit-base-patch32" # https://huggingface.co/openai/clip-vit-base-patch32 25 | ) 26 | CLIP_VIT_LARGE = ( 27 | "clip-vit-large-patch14" # https://huggingface.co/openai/clip-vit-large-patch14 28 | ) 29 | EVACLIP_VIT = "eva02_base_patch14_224.mim_in22k" # https://huggingface.co/timm/eva02_base_patch14_224.mim_in22k 30 | DATA2VEC_VISUAL = "data2vec-vision-base-ft1k" # https://huggingface.co/facebook/data2vec-vision-base-ft1k 31 | VIDEOMAE_BASE = "videomae-base" # https://huggingface.co/MCG-NJU/videomae-base 32 | VIDEOMAE_LARGE = "videomae-large" # https://huggingface.co/MCG-NJU/videomae-large 33 | DINO2_LARGE = "dinov2-large" # https://huggingface.co/facebook/dinov2-large 34 | DINO2_GIANT = "dinov2-giant" # https://huggingface.co/facebook/dinov2-giant 35 | 36 | 37 | def func_opencv_to_image(img): 38 | img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 39 | return img 40 | 41 | 42 | def func_opencv_to_numpy(img): 43 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 44 | return img 45 | 46 | 47 | def func_read_frames(face_dir, vid): 48 | npy_path = os.path.join(face_dir, vid, f"{vid}.npy") 49 | assert os.path.exists(npy_path), f"Error: {vid} does not have frames.npy!" 
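# The per-video .npy loaded here is produced by extract_openface.py (generate_face_faceDir /
# generate_face_videoOne), i.e. the aligned face crops read with cv2.imread stacked into one array
# of BGR images; they are converted to RGB later by func_opencv_to_image / func_opencv_to_numpy.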
50 | frames = np.load(npy_path) 51 | return frames 52 | 53 | 54 | # Strategy 2: revised VideoMAE sampling, using the sampling code provided by sunlicai 55 | def func_get_indexes_for_videomae(total_frames): 56 | 57 | clip_len = 16 58 | test_num_segment = 2 59 | frame_sample_rate = 4 60 | 61 | # sampling according to 'frame_sample_rate' 62 | all_index = [x for x in range(0, total_frames, frame_sample_rate)] 63 | while len(all_index) < clip_len: 64 | all_index.append(all_index[-1]) 65 | 66 | # get start/end index according to 'segment_id' => guarantees at least 16 frames remain 67 | whole_segment_index = [] 68 | for segment_id in range(test_num_segment): 69 | temporal_step = max( 70 | 1.0 * (len(all_index) - clip_len) / (test_num_segment - 1), 0 71 | ) 72 | temporal_start = int(segment_id * temporal_step) 73 | segment_index = all_index[temporal_start : temporal_start + clip_len] 74 | print(f"segment '{segment_id+1}' index: {segment_index}") 75 | whole_segment_index.append(segment_index) 76 | return whole_segment_index 77 | 78 | 79 | def resample_frames_sunlicai(frames): 80 | batches = [] 81 | whole_segment_index = func_get_indexes_for_videomae(len(frames)) 82 | for segment_index in whole_segment_index: 83 | assert len(segment_index) == 16 84 | batches.append(np.array(frames)[segment_index]) 85 | return batches 86 | 87 | 88 | # Strategy 3: samples more uniformly than the above [replaces the videomae sampling and re-tested] 89 | def resample_frames_uniform(frames, nframe=16): 90 | vlen = len(frames) 91 | start, end = 0, vlen 92 | 93 | n_frms_update = min(nframe, vlen) # for vlen < n_frms, only read vlen 94 | indices = np.arange(start, end, vlen / n_frms_update).astype(int).tolist() 95 | 96 | # whether compress into 'n_frms' 97 | while len(indices) < nframe: 98 | indices.append(indices[-1]) 99 | indices = indices[:nframe] 100 | assert len(indices) == nframe, f"{indices}, {vlen}, {nframe}" 101 | return frames[indices] 102 | 103 | 104 | def split_into_batch(inputs, bsize=32): 105 | batches = [] 106 | for ii in range(math.ceil(len(inputs) / bsize)): 107 | batch = inputs[ii * bsize : (ii + 1) * bsize] 108 | batches.append(batch) 109 | return batches 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser(description="Run.") 114 | parser.add_argument( 115 | "--dataset", 116 | type=str, 117 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/", 118 | help="input dataset", 119 | ) 120 | parser.add_argument( 121 | "--model_name", 122 | type=str, 123 | default="clip-vit-large-patch14", 124 | help="name of pretrained model", 125 | ) 126 | parser.add_argument( 127 | "--feature_level", 128 | type=str, 129 | default="UTTERANCE", 130 | help="feature level [FRAME or UTTERANCE]", 131 | ) 132 | parser.add_argument( 133 | "--videomae_type", 134 | type=str, 135 | default=None, 136 | help="videomae input type: [None or sunlicai]", 137 | ) 138 | parser.add_argument("--gpu", type=int, default=0, help="gpu id") 139 | params = parser.parse_args() 140 | 141 | print(f"==> Extracting {params.model_name} embeddings...") 142 | model_name = params.model_name.split(".")[0] 143 | face_dir = os.path.join(params.dataset, "features", "openface_face") 144 | 145 | start_time = time.time() 146 | 147 | # gain save_dir 148 | if params.videomae_type is None: 149 | save_dir = os.path.join( 150 | params.dataset, 151 | "features", 152 | f"{model_name}-{params.feature_level[:3]}", 153 | ) 154 | else: 155 | assert params.videomae_type == "sunlicai" 156 | save_dir = os.path.join( 157 | params.dataset, 158 | "features", 159 | f"sunlicai-{model_name}-{params.feature_level[:3]}", 160 | ) 161 | if not
os.path.exists(save_dir): 162 | os.makedirs(save_dir) 163 | 164 | # load model 165 | if params.model_name in [ 166 | CLIP_VIT_BASE, 167 | CLIP_VIT_LARGE, 168 | DATA2VEC_VISUAL, 169 | VIDEOMAE_BASE, 170 | VIDEOMAE_LARGE, 171 | ]: # from huggingface 172 | model_dir = os.path.join( 173 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{params.model_name}" 174 | ) 175 | model = AutoModel.from_pretrained(model_dir) 176 | processor = AutoFeatureExtractor.from_pretrained(model_dir) 177 | elif params.model_name in [DINO2_LARGE, DINO2_GIANT]: 178 | model_dir = os.path.join( 179 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{params.model_name}" 180 | ) 181 | model = AutoModel.from_pretrained(model_dir) 182 | processor = AutoImageProcessor.from_pretrained(model_dir) 183 | elif params.model_name in [EVACLIP_VIT]: # from timm 184 | model_path = os.path.join( 185 | config.PATH_TO_PRETRAINED_MODELS, 186 | f"timm/{params.model_name}/model.safetensors", 187 | ) 188 | model = timm.create_model( 189 | params.model_name, 190 | pretrained=True, 191 | num_classes=0, 192 | pretrained_cfg_overlay=dict(file=model_path), 193 | ) 194 | data_config = timm.data.resolve_model_data_config(model) 195 | transforms = timm.data.create_transform(**data_config, is_training=False) 196 | 197 | # move the model to cuda only when a gpu is available 198 | if params.gpu != -1: 199 | torch.cuda.set_device(params.gpu) 200 | model.cuda() 201 | model.eval() 202 | 203 | load_time = time.time() 204 | 205 | # extract embedding video by video 206 | vids = os.listdir(face_dir) 207 | EMBEDDING_DIM = -1 208 | print(f'Find total "{len(vids)}" videos.') 209 | for i, vid in enumerate(vids, 1): 210 | print(f"Processing video '{vid}' ({i}/{len(vids)})...") 211 | # save_file = os.path.join(save_dir, f'{vid}.npy') 212 | # if os.path.exists(save_file): continue 213 | 214 | # forward process [different model has its unique mode, it is hard to unify them as one process] 215 | # => split into batch to reduce memory usage 216 | with torch.no_grad(): 217 | frames = func_read_frames(face_dir, vid) 218 | if params.model_name in [CLIP_VIT_BASE, CLIP_VIT_LARGE]: 219 | frames = [func_opencv_to_image(frame) for frame in frames] 220 | inputs = processor(images=frames, return_tensors="pt")["pixel_values"] 221 | if params.gpu != -1: 222 | inputs = inputs.to("cuda") 223 | batches = split_into_batch(inputs, bsize=32) 224 | embeddings = [] 225 | for batch in batches: 226 | embeddings.append(model.get_image_features(batch)) # [58, 768] 227 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768] 228 | 229 | elif params.model_name in [DATA2VEC_VISUAL]: 230 | frames = [func_opencv_to_image(frame) for frame in frames] 231 | inputs = processor(images=frames, return_tensors="pt")[ 232 | "pixel_values" 233 | ] # [nframe, 3, 224, 224] 234 | if params.gpu != -1: 235 | inputs = inputs.to("cuda") 236 | batches = split_into_batch(inputs, bsize=32) 237 | embeddings = [] 238 | for batch in batches: # [32, 3, 224, 224] 239 | hidden_states = model( 240 | batch, output_hidden_states=True 241 | ).hidden_states # [58, 196 patch + 1 cls, feat=768] 242 | embeddings.append( 243 | torch.stack(hidden_states)[-1].sum(dim=1) 244 | ) # [58, 768] 245 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768] 246 | 247 | elif params.model_name in [DINO2_LARGE, DINO2_GIANT]: 248 | frames = resample_frames_uniform( 249 | frames, nframe=64 250 | ) # speed up feature extraction: sample 64 frames more uniformly this way 251 | frames = [func_opencv_to_image(frame) for frame in frames] 252 | inputs = processor(images=frames, return_tensors="pt")[ 253 |
"pixel_values" 254 | ] # [nframe, 3, 224, 224] 255 | if params.gpu != -1: 256 | inputs = inputs.to("cuda") 257 | batches = split_into_batch(inputs, bsize=32) 258 | embeddings = [] 259 | for batch in batches: # [32, 3, 224, 224] 260 | hidden_states = model( 261 | batch, output_hidden_states=True 262 | ).hidden_states # [58, 196 patch + 1 cls, feat=768] 263 | embeddings.append( 264 | torch.stack(hidden_states)[-1].sum(dim=1) 265 | ) # [58, 768] 266 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768] 267 | 268 | elif params.model_name in [VIDEOMAE_BASE, VIDEOMAE_LARGE]: 269 | # videoVAE: only supports 16 frames inputs 270 | if params.videomae_type == "sunlicai": 271 | batches = resample_frames_sunlicai(frames) 272 | else: 273 | batches = [ 274 | resample_frames_uniform(frames) 275 | ] # convert to list of batches 276 | embeddings = [] 277 | for batch in batches: 278 | frames = [ 279 | func_opencv_to_numpy(frame) for frame in batch 280 | ] # 16 * [112, 112, 3] 281 | inputs = processor(list(frames), return_tensors="pt")[ 282 | "pixel_values" 283 | ] # [1, 16, 3, 224, 224] 284 | if params.gpu != -1: 285 | inputs = inputs.to("cuda") 286 | outputs = model(inputs).last_hidden_state # [1, 1586, 768] 287 | num_patches_per_frame = ( 288 | model.config.image_size // model.config.patch_size 289 | ) ** 2 # 14*14 290 | outputs = outputs.view( 291 | 16 // model.config.tubelet_size, num_patches_per_frame, -1 292 | ) # [seg_number, patch, featdim] 293 | embeddings.append(outputs.mean(dim=1)) # [seg_number, featdim] 294 | embeddings = torch.cat(embeddings, axis=0) 295 | 296 | elif params.model_name in [EVACLIP_VIT]: 297 | frames = [func_opencv_to_image(frame) for frame in frames] 298 | inputs = torch.stack( 299 | [transforms(frame) for frame in frames] 300 | ) # [117, 3, 224, 224] 301 | if params.gpu != -1: 302 | inputs = inputs.to("cuda") 303 | batches = split_into_batch(inputs, bsize=32) 304 | embeddings = [] 305 | for batch in batches: 306 | embeddings.append(model(batch)) # [58, 768] 307 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768] 308 | 309 | embeddings = embeddings.detach().squeeze().cpu().numpy() 310 | EMBEDDING_DIM = max(EMBEDDING_DIM, np.shape(embeddings)[-1]) 311 | 312 | # save into npy 313 | save_file = os.path.join(save_dir, f"{vid}.npy") 314 | if params.feature_level == "FRAME": 315 | embeddings = np.array(embeddings).squeeze() 316 | if len(embeddings) == 0: 317 | embeddings = np.zeros((1, EMBEDDING_DIM)) 318 | elif len(embeddings.shape) == 1: 319 | embeddings = embeddings[np.newaxis, :] 320 | np.save(save_file, embeddings) 321 | else: 322 | embeddings = np.array(embeddings).squeeze() 323 | if len(embeddings) == 0: 324 | embeddings = np.zeros((EMBEDDING_DIM,)) 325 | elif len(embeddings.shape) == 2: 326 | embeddings = np.mean(embeddings, axis=0) 327 | np.save(save_file, embeddings) 328 | 329 | end_time = time.time() 330 | print(f"Model load time used: {load_time - start_time:.1f}s.") 331 | print(f"Total time used: {end_time - start_time:.1f}s.") 332 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | import datetime 4 | import logging 5 | import os 6 | import os.path as osp 7 | import random 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import torch.optim as optim 13 | from omegaconf import OmegaConf 14 | from tqdm import tqdm 15 | 16 | import config 17 | from 
dataloader.dataloader import get_dataloader_train, get_dataloader_valid 18 | from models import get_models 19 | from toolkit.utils.draw_process import draw_loss, draw_metric 20 | from toolkit.utils.eval import calculate_results 21 | from toolkit.utils.functions import func_update_storage, merge_args_config 22 | from toolkit.utils.loss import CELoss, MSELoss 23 | from toolkit.utils.metric import gain_metric_from_results 24 | 25 | 26 | def set_seed(seed=0): 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | np.random.seed(seed) 31 | random.seed(seed) 32 | torch.backends.cudnn.benchmark = False 33 | torch.backends.cudnn.deterministic = True 34 | 35 | 36 | def save_model(model, save_path): 37 | if hasattr(model, "module"): 38 | model_state = model.module.state_dict() 39 | else: 40 | model_state = model.state_dict() 41 | 42 | torch.save( 43 | model_state, 44 | osp.join(save_path), 45 | ) 46 | return 47 | 48 | 49 | def train_or_eval_model( 50 | args, model, reg_loss, cls_loss, dataloader, epoch, optimizer=None, train=False 51 | ): 52 | 53 | if not train: 54 | vidnames = [] 55 | val_preds, val_labels = [], [] 56 | emo_probs, emo_labels = [], [] 57 | losses = [] 58 | 59 | assert not train or optimizer != None 60 | 61 | if train: 62 | model.train() 63 | else: 64 | model.eval() 65 | 66 | if train: 67 | pbar = tqdm(range(len(dataloader)), desc=f"epoch:{epoch+1}, train") 68 | else: 69 | pbar = tqdm(range(len(dataloader)), desc=f"epoch:{epoch+1}, valid") 70 | 71 | for iter, data in enumerate(dataloader): 72 | if train: 73 | optimizer.zero_grad() 74 | 75 | # read data + cuda 76 | inputs, emos, vals, bnames = data 77 | 78 | if not train: 79 | vidnames += bnames 80 | 81 | for k in inputs.keys(): 82 | inputs[k] = inputs[k].cuda() 83 | 84 | emos = emos.long().cuda() 85 | vals = vals.float().cuda() 86 | 87 | # forward process 88 | # start_time = time.time() 89 | 90 | if args.train_input_mode == "input_gt": 91 | features, emos_out, vals_out, interloss = model([inputs, emos]) 92 | elif args.train_input_mode == "input": 93 | features, emos_out, vals_out, interloss = model(inputs) 94 | # duration = time.time() - start_time 95 | # macs, params = profile(model, inputs=(batch, )) 96 | # print(f"MACs: {macs}, Parameters: {params}, Duration: {duration}; bsize: {len(bnames)}") 97 | 98 | # loss calculation 99 | loss = interloss 100 | 101 | if args.output_dim1 != 0: 102 | loss = loss + cls_loss(emos_out, emos) 103 | if not train: 104 | emo_probs.append(emos_out.data.cpu().numpy()) 105 | emo_labels.append(emos.data.cpu().numpy()) 106 | if args.output_dim2 != 0: 107 | loss = loss + reg_loss(vals_out, vals) 108 | if not train: 109 | val_preds.append(vals_out.data.cpu().numpy()) 110 | val_labels.append(vals.data.cpu().numpy()) 111 | losses.append(loss.data.cpu().numpy()) 112 | 113 | pbar.set_postfix(**{"loss": loss.item()}) 114 | pbar.update(1) 115 | 116 | # optimize params 117 | if train: 118 | loss.backward() 119 | if model.model.grad_clip != -1: 120 | torch.nn.utils.clip_grad_value_( 121 | [param for param in model.parameters() if param.requires_grad], 122 | model.model.grad_clip, 123 | ) 124 | optimizer.step() 125 | 126 | pbar.close() 127 | 128 | if not train: 129 | if emo_probs != []: 130 | emo_probs = np.concatenate(emo_probs) 131 | if emo_labels != []: 132 | emo_labels = np.concatenate(emo_labels) 133 | if val_preds != []: 134 | val_preds = np.concatenate(val_preds) 135 | if val_labels != []: 136 | val_labels = np.concatenate(val_labels) 137 | results, _ = 
calculate_results(emo_probs, emo_labels, val_preds, val_labels) 138 | save_results = dict( 139 | names=vidnames, 140 | loss=np.mean(losses), 141 | **results, 142 | ) 143 | else: 144 | save_results = dict( 145 | loss=np.mean(losses), 146 | ) 147 | return save_results 148 | 149 | 150 | if __name__ == "__main__": 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument( 153 | "--seed", 154 | type=int, 155 | default=0, 156 | help="random seed", 157 | ) 158 | 159 | # --- data config 160 | parser.add_argument( 161 | "--dataset", 162 | type=str, 163 | default="seed1", 164 | help="dataset info name", 165 | ) 166 | parser.add_argument( 167 | "--emo_rule", 168 | type=str, 169 | default="MER", 170 | help="emo map function from emotion to index", 171 | ) 172 | 173 | # --- save config 174 | parser.add_argument( 175 | "--save_root", 176 | type=str, 177 | default="./saved", 178 | help="save prediction results and models", 179 | ) 180 | parser.add_argument( 181 | "--save_as_time", 182 | action="store_true", 183 | default=False, 184 | help="save suffix as time, default: run", 185 | ) 186 | parser.add_argument( 187 | "--save_model", 188 | action="store_true", 189 | default=False, 190 | help="whether to save model, default: False", 191 | ) 192 | 193 | # --- model config 194 | parser.add_argument( 195 | "--model", 196 | type=str, 197 | default="auto_attention", 198 | help="model name for training ", 199 | ) 200 | parser.add_argument( 201 | "--model_pretrain", 202 | type=str, 203 | default=None, 204 | help="pretrained model path", 205 | ) 206 | 207 | # --- feat config 208 | parser.add_argument( 209 | "--feat", 210 | type=str, 211 | default='["senet50face_UTT","senet50face_UTT","senet50face_UTT"]', 212 | help="use feat", 213 | ) 214 | 215 | # --- train config 216 | parser.add_argument( 217 | "--lr", type=float, default=1e-4, metavar="lr", help="set lr rate" 218 | ) 219 | # parser.add_argument( 220 | # "--lr_end", type=float, default=1e-5, metavar="lr", help="set lr rate" 221 | # ) 222 | parser.add_argument( 223 | "--l2", 224 | type=float, 225 | default=0.00001, 226 | metavar="L2", 227 | help="L2 regularization weight", 228 | ) 229 | parser.add_argument( 230 | "--batch_size", 231 | type=int, 232 | default=512, 233 | metavar="BS", 234 | help="batch size [deal with OOM]", 235 | ) 236 | parser.add_argument( 237 | "--num_workers", type=int, default=4, metavar="nw", help="number of workers" 238 | ) 239 | parser.add_argument( 240 | "--epochs", type=int, default=60, metavar="E", help="number of epochs" 241 | ) 242 | parser.add_argument("--gpu", default=0, type=int, help="GPU id to use") 243 | 244 | args = parser.parse_args() 245 | args.feat = ast.literal_eval(args.feat) 246 | 247 | set_seed(args.seed) 248 | 249 | torch.cuda.set_device(args.gpu) 250 | 251 | print("====== set save dir =======") 252 | now = datetime.datetime.now() 253 | if args.save_as_time: 254 | save_dir = os.path.join( 255 | args.save_root, args.dataset, args.model, now.strftime("%Y-%m-%d-%H-%M-%S") 256 | ) 257 | else: 258 | save_dir = os.path.join(args.save_root, args.dataset, args.model, "run") 259 | if not os.path.exists(save_dir): 260 | os.makedirs(save_dir) 261 | print("{}\n".format(save_dir)) 262 | 263 | logging_path = osp.join(save_dir, "info.log") 264 | logging.basicConfig( 265 | level=logging.INFO, 266 | format="%(levelname)s: %(message)s", 267 | filename=logging_path, 268 | filemode="w", 269 | ) 270 | 271 | logging.info("====== load info and config =======") 272 | dataset_info = np.load( 273 | osp.join(config.PATH_TO_TRAIN_LST, 
args.dataset + ".npy"), allow_pickle=True 274 | ).item() 275 | 276 | model_config = OmegaConf.load("models/model-tune.yaml")[args.model] 277 | model_config = OmegaConf.to_container(model_config, resolve=True) 278 | 279 | emo_rule = config.EMO_RULE[args.emo_rule] 280 | model_config["emo_rule"] = emo_rule 281 | model_config["output_dim1"] = len(emo_rule) 282 | model_config["output_dim2"] = 0 283 | 284 | feat_dims = [] 285 | feat_roots = [] 286 | feat_types = [] 287 | for use_feat in args.feat: 288 | if use_feat in config.FEAT_VIDEO_DICT: 289 | feat_dims.append(config.FEAT_VIDEO_DICT[use_feat][0]) 290 | feat_roots.append(config.FEAT_VIDEO_DICT[use_feat][1]) 291 | feat_types.append("V") 292 | elif use_feat in config.FEAT_AUDIO_DICT: 293 | feat_dims.append(config.FEAT_AUDIO_DICT[use_feat][0]) 294 | feat_roots.append(config.FEAT_AUDIO_DICT[use_feat][1]) 295 | feat_types.append("A") 296 | elif use_feat in config.FEAT_TEXT_DICT: 297 | feat_dims.append(config.FEAT_TEXT_DICT[use_feat][0]) 298 | feat_roots.append(config.FEAT_TEXT_DICT[use_feat][1]) 299 | feat_types.append("T") 300 | model_config["feat_dims"] = feat_dims 301 | model_config["feat_roots"] = feat_roots 302 | for feat_type, feat_dim, feat_root in zip(feat_types, feat_dims, feat_roots): 303 | logging.info("Modality:{}, Input feature: {}".format(feat_type, feat_root)) 304 | logging.info("===> dim is (1, {}) \n".format(feat_dim)) 305 | 306 | args = merge_args_config(args, model_config) # 两部分参数重叠 307 | logging.info("save config: {} \n".format(osp.join(save_dir, "args.yaml"))) 308 | OmegaConf.save(vars(args), osp.join(save_dir, "args.yaml")) 309 | 310 | logging.info("====== load dataset =======") 311 | train_info = dataset_info["train_info"] 312 | train_loader = get_dataloader_train( 313 | names=train_info["names"], 314 | emo_labels=train_info["emos"], 315 | val_labels=train_info["vals"], 316 | emo_rule=emo_rule, 317 | feat_roots=args.feat_roots, 318 | batch_size=args.batch_size, 319 | num_workers=args.num_workers, 320 | shuffle=True, 321 | ) 322 | valid_info = dataset_info["valid_info"] 323 | valid_loader = get_dataloader_valid( 324 | names=valid_info["names"], 325 | emo_labels=valid_info["emos"], 326 | val_labels=valid_info["vals"], 327 | emo_rule=emo_rule, 328 | feat_roots=args.feat_roots, 329 | batch_size=args.batch_size, 330 | num_workers=args.num_workers, 331 | shuffle=False, 332 | ) 333 | 334 | logging.info("====== build model =======") 335 | model = get_models(args).cuda() 336 | cls_loss = CELoss().cuda() 337 | reg_loss = MSELoss().cuda() 338 | optimizer = optim.Adam( 339 | (param for param in model.parameters() if param.requires_grad), 340 | lr=args.lr, 341 | weight_decay=args.l2, 342 | ) 343 | logging.info("load model: {} \n".format(args.model)) 344 | 345 | if args.model_pretrain is not None: 346 | model.load_state_dict(torch.load(args.model_pretrain, map_location=f"cuda:0")) 347 | 348 | logging.info("====== Training and Evaluation =======") 349 | best_eval_metric = None # for select best model weight 350 | record_train = {"epoch": [], "loss": []} 351 | record_valid = {"epoch": [], "loss": [], "emoacc": [], "emofscore": []} 352 | best_valid = {} 353 | for epoch in range(args.epochs): 354 | logging.info( 355 | "epoch: {}, lr:{:.8f}".format( 356 | epoch + 1, optimizer.state_dict()["param_groups"][0]["lr"] 357 | ) 358 | ) 359 | 360 | train_results = train_or_eval_model( 361 | args, 362 | model, 363 | reg_loss, 364 | cls_loss, 365 | train_loader, 366 | epoch=epoch, 367 | optimizer=optimizer, 368 | train=True, 369 | ) 370 | valid_results 
= train_or_eval_model( 371 | args, 372 | model, 373 | reg_loss, 374 | cls_loss, 375 | valid_loader, 376 | epoch=epoch, 377 | optimizer=None, 378 | train=False, 379 | ) 380 | 381 | # --- logging info 382 | logging_info = "\tTRAIN: loss:{:.4f}".format(train_results["loss"]) 383 | logging.info(logging_info) 384 | 385 | logging_info = "\tVALID: loss:{:.4f}, acc:{:.4f}, f-score:{:.4f}".format( 386 | valid_results["loss"], 387 | valid_results["emoacc"], 388 | valid_results["emofscore"], 389 | ) 390 | logging.info(logging_info) 391 | 392 | # --- record data and plot 393 | record_train["epoch"].append(epoch + 1) 394 | record_train["loss"].append(train_results["loss"]) 395 | 396 | record_valid["epoch"].append(epoch + 1) 397 | for key in record_valid: 398 | if key == "epoch": 399 | continue 400 | record_valid[key].append(valid_results[key]) 401 | 402 | # --- save best model 403 | ## select metric 404 | eval_metric = valid_results["emofscore"] 405 | 406 | # save eval results, model config and weight 407 | if best_eval_metric is None or best_eval_metric < eval_metric: 408 | best_eval_metric = eval_metric 409 | # save best result info in valid dataset 410 | best_valid["epoch"] = epoch + 1 411 | for key in record_valid: 412 | if key in ["epoch"]: 413 | continue 414 | best_valid[key] = record_valid[key][-1] 415 | 416 | # save best result in valid dataset 417 | epoch_store = {} 418 | func_update_storage( 419 | inputs=valid_results, prefix="eval", outputs=epoch_store 420 | ) 421 | np.save(osp.join(save_dir, "best_valid_results.npy"), epoch_store) 422 | 423 | # save best model weight and args 424 | np.save(osp.join(save_dir, "best_args.npy"), {"args": args}) 425 | if args.save_model: 426 | save_path = f"{save_dir}/best.pth" 427 | save_model(model, save_path) 428 | 429 | logging.info("\t*** Update best info ! ***") 430 | 431 | draw_loss( 432 | record_train["epoch"], 433 | record_train["loss"], 434 | record_valid["loss"], 435 | osp.join(save_dir, "loss.png"), 436 | ) 437 | for key in record_valid: 438 | if key in ["epoch", "loss"]: 439 | continue 440 | draw_metric( 441 | record_valid["epoch"], 442 | record_valid[key], 443 | key, 444 | osp.join(save_dir, "{}.png".format(key)), 445 | ) 446 | 447 | logging.info("End Training ! 
\n\n") 448 | logging_info = "BEST: epoch:{:.0f}, loss:{:.4f}, acc:{:.4f}, f-score:{:.4f}".format( 449 | best_valid["epoch"], 450 | best_valid["loss"], 451 | best_valid["emoacc"], 452 | best_valid["emofscore"], 453 | ) 454 | logging.info(logging_info) 455 | 456 | # clear memory 457 | del model 458 | del optimizer 459 | torch.cuda.empty_cache() 460 | -------------------------------------------------------------------------------- /models/emo_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | 7 | def __init__( 8 | self, inp_dim, out_dim, dropout, mid_dim=None, num_heads=4, num_layers=2 9 | ): 10 | 11 | super(Encoder, self).__init__() 12 | if isinstance(inp_dim, list): 13 | self.mode = len(inp_dim) 14 | else: 15 | self.mode = 1 16 | 17 | if mid_dim is None: 18 | mid_dim = out_dim 19 | if self.mode == 1: 20 | # self.norm = nn.BatchNorm1d(in_size) 21 | self.linear = nn.Linear(inp_dim, mid_dim) 22 | self.net = nn.TransformerEncoder( 23 | nn.TransformerEncoderLayer( 24 | d_model=mid_dim, 25 | nhead=num_heads, 26 | dim_feedforward=2048, 27 | dropout=dropout, 28 | batch_first=True, 29 | ), 30 | num_layers=num_layers, 31 | ) 32 | self.linear2 = nn.Linear(mid_dim, out_dim) 33 | elif self.mode == 2: 34 | self.linear = nn.Linear(35, mid_dim) 35 | self.cls_token = nn.Parameter(torch.zeros(1, 1, mid_dim)) 36 | self.net = nn.TransformerEncoder( 37 | nn.TransformerEncoderLayer( 38 | d_model=mid_dim, 39 | nhead=num_heads, 40 | dim_feedforward=2048, 41 | dropout=dropout, 42 | batch_first=True, 43 | ), 44 | num_layers=num_layers, 45 | ) 46 | self.linear2 = nn.Linear(mid_dim, out_dim) 47 | 48 | def forward(self, x): 49 | """ 50 | Args: 51 | x: tensor of shape (batch_size, in_size) 52 | """ 53 | # normed = self.norm(x) 54 | 55 | if self.mode == 1: 56 | x = self.linear(x) 57 | x = self.net(x) 58 | elif self.mode == 2: 59 | x = x[:, :, 674:] # only AU feature in openface features 60 | x = self.linear(x) 61 | B, L, C = x.shape 62 | cls_token = self.cls_token.expand(B, -1, -1) # (B, 1, mid_dim) 63 | x = torch.cat((cls_token, x), dim=1) # (B, L+1, mid_dim) 64 | x = self.net(x) # (B, L+1, mid_dim) 65 | x = x[:, 0, :] 66 | 67 | x = self.linear2(x) 68 | 69 | return x 70 | 71 | 72 | class FcClassifier(nn.Module): 73 | def __init__(self, hidden_dim, cls_layers, output_dim=6, dropout=0, use_bn=True): 74 | super(FcClassifier, self).__init__() 75 | self.fc_layers = nn.ModuleList() 76 | self.fc_layers.append(nn.Linear(hidden_dim, cls_layers[0])) 77 | if use_bn is True: 78 | self.fc_layers.append(nn.BatchNorm1d(cls_layers[0])) 79 | self.fc_layers.append(nn.ReLU(inplace=True)) 80 | if dropout > 0: 81 | self.fc_layers.append(nn.Dropout(dropout)) 82 | 83 | for i in range(1, len(cls_layers)): 84 | self.fc_layers.append(nn.Linear(cls_layers[i - 1], cls_layers[i])) 85 | if use_bn is True: 86 | self.fc_layers.append(nn.BatchNorm1d(cls_layers[i])) 87 | self.fc_layers.append(nn.ReLU(inplace=True)) 88 | if dropout > 0: 89 | self.fc_layers.append(nn.Dropout(dropout)) 90 | 91 | self.output_layer = nn.Linear(cls_layers[-1], output_dim) 92 | 93 | def forward(self, x): 94 | for layer in self.fc_layers: 95 | x = layer(x) 96 | x = self.output_layer(x) 97 | return x 98 | 99 | 100 | class Auto_WeightV1(nn.Module): 101 | def __init__(self, args): 102 | super(Auto_WeightV1, self).__init__() 103 | feat_dims = args.feat_dims 104 | output_dim1 = args.output_dim1 105 | output_dim2 = args.output_dim2 106 | dropout = args.dropout 
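        # Auto_WeightV1 builds one branch per input feature stream (up to 15 streams):
        #   - Encoder: projects feat_dims[i] -> hidden_dim with a small Transformer.
        #   - FcClassifier([256, 128]): per-stream emotion logits with output_dim1 classes.
        # The concatenated per-stream hidden states then feed a fused emotion head
        # (fc_out_1), a second head for the valence output (fc_out_2), and weight_fc,
        # which predicts len(feat_dims) + 1 weights for combining the per-stream logits
        # with the fused logits in forward(). Note that attention_mlp / fc_att are
        # initialized below but are not referenced in forward().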
107 | hidden_dim = args.hidden_dim 108 | self.grad_clip = args.grad_clip 109 | self.feat_dims = feat_dims 110 | 111 | num_heads = args.num_heads 112 | num_layers = args.num_layers 113 | 114 | assert args.feat_type == "utt" 115 | 116 | assert len(feat_dims) <= 3 * 6 117 | if len(feat_dims) >= 1: 118 | self.encoder0 = Encoder( 119 | feat_dims[0], 120 | hidden_dim, 121 | dropout, 122 | num_heads=num_heads, 123 | num_layers=num_layers, 124 | ) 125 | self.cls0 = FcClassifier( 126 | hidden_dim, 127 | [256, 128], 128 | output_dim=output_dim1, 129 | dropout=0, 130 | use_bn=True, 131 | ) 132 | if len(feat_dims) >= 2: 133 | self.encoder1 = Encoder( 134 | feat_dims[1], 135 | hidden_dim, 136 | dropout, 137 | num_heads=num_heads, 138 | num_layers=num_layers, 139 | ) 140 | self.cls1 = FcClassifier( 141 | hidden_dim, 142 | [256, 128], 143 | output_dim=output_dim1, 144 | dropout=0, 145 | use_bn=True, 146 | ) 147 | if len(feat_dims) >= 3: 148 | self.encoder2 = Encoder( 149 | feat_dims[2], 150 | hidden_dim, 151 | dropout, 152 | num_heads=num_heads, 153 | num_layers=num_layers, 154 | ) 155 | self.cls2 = FcClassifier( 156 | hidden_dim, 157 | [256, 128], 158 | output_dim=output_dim1, 159 | dropout=0, 160 | use_bn=True, 161 | ) 162 | if len(feat_dims) >= 4: 163 | self.encoder3 = Encoder( 164 | feat_dims[3], 165 | hidden_dim, 166 | dropout, 167 | num_heads=num_heads, 168 | num_layers=num_layers, 169 | ) 170 | self.cls3 = FcClassifier( 171 | hidden_dim, 172 | [256, 128], 173 | output_dim=output_dim1, 174 | dropout=0, 175 | use_bn=True, 176 | ) 177 | if len(feat_dims) >= 5: 178 | self.encoder4 = Encoder( 179 | feat_dims[4], 180 | hidden_dim, 181 | dropout, 182 | num_heads=num_heads, 183 | num_layers=num_layers, 184 | ) 185 | self.cls4 = FcClassifier( 186 | hidden_dim, 187 | [256, 128], 188 | output_dim=output_dim1, 189 | dropout=0, 190 | use_bn=True, 191 | ) 192 | if len(feat_dims) >= 6: 193 | self.encoder5 = Encoder( 194 | feat_dims[5], 195 | hidden_dim, 196 | dropout, 197 | num_heads=num_heads, 198 | num_layers=num_layers, 199 | ) 200 | self.cls5 = FcClassifier( 201 | hidden_dim, 202 | [256, 128], 203 | output_dim=output_dim1, 204 | dropout=0, 205 | use_bn=True, 206 | ) 207 | if len(feat_dims) >= 7: 208 | self.encoder6 = Encoder( 209 | feat_dims[6], 210 | hidden_dim, 211 | dropout, 212 | num_heads=num_heads, 213 | num_layers=num_layers, 214 | ) 215 | self.cls6 = FcClassifier( 216 | hidden_dim, 217 | [256, 128], 218 | output_dim=output_dim1, 219 | dropout=0, 220 | use_bn=True, 221 | ) 222 | if len(feat_dims) >= 8: 223 | self.encoder7 = Encoder( 224 | feat_dims[7], 225 | hidden_dim, 226 | dropout, 227 | num_heads=num_heads, 228 | num_layers=num_layers, 229 | ) 230 | self.cls7 = FcClassifier( 231 | hidden_dim, 232 | [256, 128], 233 | output_dim=output_dim1, 234 | dropout=0, 235 | use_bn=True, 236 | ) 237 | if len(feat_dims) >= 9: 238 | self.encoder8 = Encoder( 239 | feat_dims[8], 240 | hidden_dim, 241 | dropout, 242 | num_heads=num_heads, 243 | num_layers=num_layers, 244 | ) 245 | self.cls8 = FcClassifier( 246 | hidden_dim, 247 | [256, 128], 248 | output_dim=output_dim1, 249 | dropout=0, 250 | use_bn=True, 251 | ) 252 | if len(feat_dims) >= 10: 253 | self.encoder9 = Encoder( 254 | feat_dims[9], 255 | hidden_dim, 256 | dropout, 257 | num_heads=num_heads, 258 | num_layers=num_layers, 259 | ) 260 | self.cls9 = FcClassifier( 261 | hidden_dim, 262 | [256, 128], 263 | output_dim=output_dim1, 264 | dropout=0, 265 | use_bn=True, 266 | ) 267 | if len(feat_dims) >= 11: 268 | self.encoder10 = Encoder( 269 | feat_dims[10], 270 
| hidden_dim, 271 | dropout, 272 | num_heads=num_heads, 273 | num_layers=num_layers, 274 | ) 275 | self.cls10 = FcClassifier( 276 | hidden_dim, 277 | [256, 128], 278 | output_dim=output_dim1, 279 | dropout=0, 280 | use_bn=True, 281 | ) 282 | if len(feat_dims) >= 12: 283 | self.encoder11 = Encoder( 284 | feat_dims[11], 285 | hidden_dim, 286 | dropout, 287 | num_heads=num_heads, 288 | num_layers=num_layers, 289 | ) 290 | self.cls11 = FcClassifier( 291 | hidden_dim, 292 | [256, 128], 293 | output_dim=output_dim1, 294 | dropout=0, 295 | use_bn=True, 296 | ) 297 | if len(feat_dims) >= 13: 298 | self.encoder12 = Encoder( 299 | feat_dims[12], 300 | hidden_dim, 301 | dropout, 302 | num_heads=num_heads, 303 | num_layers=num_layers, 304 | ) 305 | self.cls12 = FcClassifier( 306 | hidden_dim, 307 | [256, 128], 308 | output_dim=output_dim1, 309 | dropout=0, 310 | use_bn=True, 311 | ) 312 | if len(feat_dims) >= 14: 313 | self.encoder13 = Encoder( 314 | feat_dims[13], 315 | hidden_dim, 316 | dropout, 317 | num_heads=num_heads, 318 | num_layers=num_layers, 319 | ) 320 | self.cls13 = FcClassifier( 321 | hidden_dim, 322 | [256, 128], 323 | output_dim=output_dim1, 324 | dropout=0, 325 | use_bn=True, 326 | ) 327 | if len(feat_dims) >= 15: 328 | self.encoder14 = Encoder( 329 | feat_dims[14], 330 | hidden_dim, 331 | dropout, 332 | num_heads=num_heads, 333 | num_layers=num_layers, 334 | ) 335 | self.cls14 = FcClassifier( 336 | hidden_dim, 337 | [256, 128], 338 | output_dim=output_dim1, 339 | dropout=0, 340 | use_bn=True, 341 | ) 342 | 343 | # --- feature attention 344 | self.attention_mlp = Encoder( 345 | hidden_dim * len(feat_dims), 346 | hidden_dim, 347 | dropout, 348 | num_heads=num_heads, 349 | num_layers=num_layers, 350 | ) 351 | self.fc_att = nn.Linear(hidden_dim, len(feat_dims)) 352 | 353 | # --- logit weight 354 | self.weight_fc = nn.Sequential( 355 | nn.Linear(hidden_dim * len(feat_dims), hidden_dim), 356 | nn.GELU(), 357 | nn.Linear(hidden_dim, len(feat_dims) + 1), 358 | ) 359 | 360 | self.fc_out_1 = FcClassifier( 361 | hidden_dim * len(feat_dims), 362 | [256, 128], 363 | output_dim=output_dim1, 364 | dropout=0, 365 | use_bn=True, 366 | ) 367 | self.fc_out_2 = FcClassifier( 368 | hidden_dim * len(feat_dims), 369 | [256, 128], 370 | output_dim=output_dim2, 371 | dropout=0, 372 | use_bn=True, 373 | ) 374 | 375 | self.criterion = nn.CrossEntropyLoss() 376 | 377 | def forward(self, inp): 378 | """ 379 | support feat_type: utt | frm-align | frm-unalign 380 | """ 381 | 382 | [batch, emos] = inp 383 | 384 | hiddens = [] 385 | logits = [] 386 | if len(self.feat_dims) >= 1: 387 | hiddens.append(self.encoder0(batch[f"feat0"])) 388 | logits.append(self.cls0(hiddens[0])) 389 | if len(self.feat_dims) >= 2: 390 | hiddens.append(self.encoder1(batch[f"feat1"])) 391 | logits.append(self.cls1(hiddens[1])) 392 | if len(self.feat_dims) >= 3: 393 | hiddens.append(self.encoder2(batch[f"feat2"])) 394 | logits.append(self.cls2(hiddens[2])) 395 | if len(self.feat_dims) >= 4: 396 | hiddens.append(self.encoder3(batch[f"feat3"])) 397 | logits.append(self.cls3(hiddens[3])) 398 | if len(self.feat_dims) >= 5: 399 | hiddens.append(self.encoder4(batch[f"feat4"])) 400 | logits.append(self.cls4(hiddens[4])) 401 | if len(self.feat_dims) >= 6: 402 | hiddens.append(self.encoder5(batch[f"feat5"])) 403 | logits.append(self.cls5(hiddens[5])) 404 | if len(self.feat_dims) >= 7: 405 | hiddens.append(self.encoder6(batch[f"feat6"])) 406 | logits.append(self.cls6(hiddens[6])) 407 | if len(self.feat_dims) >= 8: 408 | 
hiddens.append(self.encoder7(batch[f"feat7"])) 409 | logits.append(self.cls7(hiddens[7])) 410 | if len(self.feat_dims) >= 9: 411 | hiddens.append(self.encoder8(batch[f"feat8"])) 412 | logits.append(self.cls8(hiddens[8])) 413 | if len(self.feat_dims) >= 10: 414 | hiddens.append(self.encoder9(batch[f"feat9"])) 415 | logits.append(self.cls9(hiddens[9])) 416 | if len(self.feat_dims) >= 11: 417 | hiddens.append(self.encoder10(batch[f"feat10"])) 418 | logits.append(self.cls10(hiddens[10])) 419 | if len(self.feat_dims) >= 12: 420 | hiddens.append(self.encoder11(batch[f"feat11"])) 421 | logits.append(self.cls11(hiddens[11])) 422 | if len(self.feat_dims) >= 13: 423 | hiddens.append(self.encoder12(batch[f"feat12"])) 424 | logits.append(self.cls12(hiddens[12])) 425 | if len(self.feat_dims) >= 14: 426 | hiddens.append(self.encoder13(batch[f"feat13"])) 427 | logits.append(self.cls13(hiddens[13])) 428 | if len(self.feat_dims) >= 15: 429 | hiddens.append(self.encoder14(batch[f"feat14"])) 430 | logits.append(self.cls14(hiddens[14])) 431 | 432 | multi_hidden1 = torch.cat(hiddens, dim=1) # [32, 384] 433 | 434 | logits.append(self.fc_out_1(multi_hidden1)) 435 | vals_out = self.fc_out_2(multi_hidden1) 436 | 437 | weight = self.weight_fc(multi_hidden1).unsqueeze(-1) 438 | 439 | multi_logits = torch.stack(logits, dim=2).permute(0, 2, 1) # [32, N, 6] 440 | emos_out = (weight * multi_logits).sum(dim=1) 441 | 442 | if self.training: 443 | interloss = self.cal_loss( 444 | logits, 445 | emos, 446 | ) 447 | else: 448 | interloss = torch.tensor(0).cuda() 449 | 450 | return None, emos_out, vals_out, interloss 451 | 452 | def cal_loss( 453 | self, 454 | logits, 455 | emos, 456 | ): 457 | # feat_A_concat, feat_V_concat, feat_L_concat): 458 | emos = emos.to(logits[0].device) 459 | 460 | loss = 0 461 | for logit in logits: 462 | loss += self.criterion(logit, emos) / len(logits) 463 | 464 | return loss 465 | -------------------------------------------------------------------------------- /feature_extraction/text/extract_text_huggingface.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import time 4 | import argparse 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import torch 9 | from transformers import ( 10 | AutoModel, 11 | BertTokenizer, 12 | AutoTokenizer, 13 | ) # version: 4.5.1, pip install transformers 14 | from transformers import GPT2Tokenizer, GPT2Model, AutoModelForCausalLM 15 | 16 | # local folder 17 | import sys 18 | 19 | current_file_path = os.path.abspath(__file__) 20 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path))) 21 | import config 22 | 23 | ##################### English ##################### 24 | BERT_BASE = "bert-base-cased" 25 | BERT_LARGE = "bert-large-cased" 26 | BERT_BASE_UNCASED = "bert-base-uncased" 27 | BERT_LARGE_UNCASED = "bert-large-uncased" 28 | ALBERT_BASE = "albert-base-v2" 29 | ALBERT_LARGE = "albert-large-v2" 30 | ALBERT_XXLARGE = "albert-xxlarge-v2" 31 | ROBERTA_BASE = "roberta-base" 32 | ROBERTA_LARGE = "roberta-large" 33 | ELECTRA_BASE = "electra-base-discriminator" 34 | ELECTRA_LARGE = "electra-large-discriminator" 35 | XLNET_BASE = "xlnet-base-cased" 36 | XLNET_LARGE = "xlnet-large-cased" 37 | T5_BASE = "t5-base" 38 | T5_LARGE = "t5-large" 39 | DEBERTA_BASE = "deberta-base" 40 | DEBERTA_LARGE = "deberta-large" 41 | DEBERTA_XLARGE = "deberta-v2-xlarge" 42 | DEBERTA_XXLARGE = "deberta-v2-xxlarge" 43 | 44 | ##################### Chinese ##################### 45 | 
BERT_BASE_CHINESE = "bert-base-chinese" # https://huggingface.co/bert-base-chinese 46 | ROBERTA_BASE_CHINESE = ( 47 | "chinese-roberta-wwm-ext" # https://huggingface.co/hfl/chinese-roberta-wwm-ext 48 | ) 49 | ROBERTA_LARGE_CHINESE = "chinese-roberta-wwm-ext-large" # https://huggingface.co/hfl/chinese-roberta-wwm-ext-large 50 | DEBERTA_LARGE_CHINESE = ( 51 | "deberta-chinese-large" # https://huggingface.co/WENGSYX/Deberta-Chinese-Large 52 | ) 53 | ELECTRA_SMALL_CHINESE = "chinese-electra-180g-small" # https://huggingface.co/hfl/chinese-electra-180g-small-discriminator 54 | ELECTRA_BASE_CHINESE = "chinese-electra-180g-base" # https://huggingface.co/hfl/chinese-electra-180g-base-discriminator 55 | ELECTRA_LARGE_CHINESE = "chinese-electra-180g-large" # https://huggingface.co/hfl/chinese-electra-180g-large-discriminator 56 | XLNET_BASE_CHINESE = ( 57 | "chinese-xlnet-base" # https://huggingface.co/hfl/chinese-xlnet-base 58 | ) 59 | MACBERT_BASE_CHINESE = ( 60 | "chinese-macbert-base" # https://huggingface.co/hfl/chinese-macbert-base 61 | ) 62 | MACBERT_LARGE_CHINESE = ( 63 | "chinese-macbert-large" # https://huggingface.co/hfl/chinese-macbert-large 64 | ) 65 | PERT_BASE_CHINESE = "chinese-pert-base" # https://huggingface.co/hfl/chinese-pert-base 66 | PERT_LARGE_CHINESE = ( 67 | "chinese-pert-large" # https://huggingface.co/hfl/chinese-pert-large 68 | ) 69 | LERT_SMALL_CHINESE = ( 70 | "chinese-lert-small" # https://huggingface.co/hfl/chinese-lert-small 71 | ) 72 | LERT_BASE_CHINESE = "chinese-lert-base" # https://huggingface.co/hfl/chinese-lert-base 73 | LERT_LARGE_CHINESE = ( 74 | "chinese-lert-large" # https://huggingface.co/hfl/chinese-lert-large 75 | ) 76 | GPT2_CHINESE = "gpt2-chinese-cluecorpussmall" # https://huggingface.co/uer/gpt2-chinese-cluecorpussmall 77 | CLIP_CHINESE = "taiyi-clip-roberta-chinese" # https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese 78 | WENZHONG_GPT2_CHINESE = "wenzhong2-gpt2-chinese" # https://huggingface.co/IDEA-CCNL/Wenzhong2.0-GPT2-3.5B-chinese 79 | ALBERT_TINY_CHINESE = ( 80 | "albert_chinese_tiny" # https://huggingface.co/clue/albert_chinese_tiny 81 | ) 82 | ALBERT_SMALL_CHINESE = ( 83 | "albert_chinese_small" # https://huggingface.co/clue/albert_chinese_small 84 | ) 85 | SIMBERT_BASE_CHINESE = ( 86 | "simbert-base-chinese" # https://huggingface.co/WangZeJun/simbert-base-chinese 87 | ) 88 | 89 | ##################### Multilingual ##################### 90 | MPNET_BASE = "paraphrase-multilingual-mpnet-base-v2" # https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2 91 | 92 | ##################### LLM ##################### 93 | LLAMA_7B = "llama-7b-hf" # https://huggingface.co/decapoda-research/llama-7b-hf 94 | LLAMA_13B = "llama-13b-hf" # https://huggingface.co/decapoda-research/llama-13b-hf 95 | LLAMA2_7B = "llama-2-7b" # https://huggingface.co/meta-llama/Llama-2-7b 96 | LLAMA2_13B = "Llama-2-13b-hf" # https://huggingface.co/NousResearch/Llama-2-13b-hf 97 | VICUNA_7B = "vicuna-7b-v0" # https://huggingface.co/lmsys/vicuna-7b-delta-v0 98 | VICUNA_13B = ( 99 | "stable-vicuna-13b" # https://huggingface.co/CarperAI/stable-vicuna-13b-delta 100 | ) 101 | ALPACE_13B = ( 102 | "chinese-alpaca-2-13b" # https://huggingface.co/ziqingyang/chinese-alpaca-2-13b 103 | ) 104 | MOSS_7B = "moss-base-7b" # https://huggingface.co/fnlp/moss-base-7b 105 | STABLEML_7B = "stablelm-base-alpha-7b-v2" # https://huggingface.co/stabilityai/stablelm-base-alpha-7b-v2 106 | BLOOM_7B = "bloom-7b1" # https://huggingface.co/bigscience/bloom-7b1 107 | 
CHATGLM2_6B = "chatglm2-6b" # https://huggingface.co/THUDM/chatglm2-6b 108 | # reley on pytorch=2.0 => env: videollama4 + cpu 109 | FALCON_7B = "falcon-7b" # https://huggingface.co/tiiuae/falcon-7b 110 | # Baichuan: pip install transformers_stream_generator 111 | BAICHUAN_7B = "Baichuan-7B" # https://huggingface.co/baichuan-inc/Baichuan-7B 112 | BAICHUAN_13B = ( 113 | "Baichuan-13B-Base" # https://huggingface.co/baichuan-inc/Baichuan-13B-Base 114 | ) 115 | # BAICHUAN2_7B: conda install xformers -c xformers 116 | BAICHUAN2_7B = ( 117 | "Baichuan2-7B-Base" # https://huggingface.co/baichuan-inc/Baichuan2-7B-Base 118 | ) 119 | # BAICHUAN2_13B: pip install accelerate 120 | BAICHUAN2_13B = ( 121 | "Baichuan2-13B-Base" # https://huggingface.co/baichuan-inc/Baichuan2-13B-Base 122 | ) 123 | OPT_13B = "opt-13b" # https://huggingface.co/facebook/opt-13b 124 | 125 | 126 | ################################################################ 127 | # 自动删除无意义token对应的特征 128 | def find_start_end_pos(tokenizer): 129 | sentence = "今天天气真好" # 句子中没有空格 130 | input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"][0] 131 | start, end = None, None 132 | 133 | # find start, must in range [0, 1, 2] 134 | for start in range(0, 3, 1): 135 | # 因为decode有时会出现空格,因此我们显示的时候把这部分信息去掉看看 136 | outputs = tokenizer.decode(input_ids[start:]).replace(" ", "") 137 | if outputs == sentence: 138 | print(f"start: {start}; end: {end}") 139 | return start, None 140 | 141 | if outputs.startswith(sentence): 142 | break 143 | 144 | # find end, must in range [-1, -2] 145 | for end in range(-1, -3, -1): 146 | outputs = tokenizer.decode(input_ids[start:end]).replace(" ", "") 147 | if outputs == sentence: 148 | break 149 | 150 | assert tokenizer.decode(input_ids[start:end]).replace(" ", "") == sentence 151 | print(f"start: {start}; end: {end}") 152 | return start, end 153 | 154 | 155 | # 找到 batch_pos and feature_dim 156 | def find_batchpos_embdim(tokenizer, model, gpu): 157 | sentence = "今天天气真好" 158 | inputs = tokenizer(sentence, return_tensors="pt") 159 | if gpu != -1: 160 | inputs = inputs.to("cuda") 161 | 162 | with torch.no_grad(): 163 | outputs = model( 164 | **inputs, output_hidden_states=True 165 | ).hidden_states # for new version 4.5.1 166 | outputs = torch.stack(outputs)[[-1]].sum(dim=0) # sum => [batch, T, D=768] 167 | outputs = outputs.cpu().numpy() # (B, T, D) or (T, B, D) 168 | batch_pos = None 169 | if outputs.shape[0] == 1: 170 | batch_pos = 0 171 | if outputs.shape[1] == 1: 172 | batch_pos = 1 173 | assert batch_pos in [0, 1] 174 | feature_dim = outputs.shape[2] 175 | print(f"batch_pos:{batch_pos}, feature_dim:{feature_dim}") 176 | return batch_pos, feature_dim 177 | 178 | 179 | # main process 180 | def extract_embedding( 181 | model_name, 182 | trans_dir, 183 | save_dir, 184 | feature_level, 185 | gpu=-1, 186 | punc_case=None, 187 | language="chinese", 188 | model_dir=None, 189 | ): 190 | 191 | print("=" * 30 + f' Extracting "{model_name}" ' + "=" * 30) 192 | start_time = time.time() 193 | 194 | # save last four layers 195 | layer_ids = [-4, -3, -2, -1] 196 | 197 | # save_dir 198 | if punc_case is None and language == "chinese" and model_dir is None: 199 | save_dir = os.path.join(save_dir, f"{model_name}-{feature_level[:3]}") 200 | elif punc_case is not None: 201 | save_dir = os.path.join( 202 | save_dir, f"{model_name}-punc{punc_case}-{feature_level[:3]}" 203 | ) 204 | elif language == "english": 205 | save_dir = os.path.join(save_dir, f"{model_name}-langeng-{feature_level[:3]}") 206 | elif model_dir is not None: 207 | 
prefix_name = "-".join(model_dir.split("/")[-2:]) 208 | save_dir = os.path.join( 209 | save_dir, f"{prefix_name}-{model_name}-{feature_level[:3]}" 210 | ) 211 | if not os.path.exists(save_dir): 212 | os.makedirs(save_dir) 213 | 214 | # load model and tokenizer: offline mode (load cached files) # 函数都一样,但是有些位置的参数就不好压缩 215 | print("Loading pre-trained tokenizer and model...") 216 | if model_dir is None: 217 | model_dir = os.path.join( 218 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{model_name}" 219 | ) 220 | 221 | if model_name in [DEBERTA_LARGE_CHINESE, ALBERT_TINY_CHINESE, ALBERT_SMALL_CHINESE]: 222 | model = AutoModel.from_pretrained(model_dir) 223 | tokenizer = BertTokenizer.from_pretrained(model_dir, use_fast=False) 224 | elif model_name in [WENZHONG_GPT2_CHINESE]: 225 | model = GPT2Model.from_pretrained(model_dir) 226 | tokenizer = GPT2Tokenizer.from_pretrained(model_dir, use_fast=False) 227 | elif model_name in [ 228 | LLAMA_7B, 229 | LLAMA_13B, 230 | LLAMA2_7B, 231 | VICUNA_7B, 232 | VICUNA_13B, 233 | ALPACE_13B, 234 | OPT_13B, 235 | BLOOM_7B, 236 | ]: 237 | model = AutoModel.from_pretrained(model_dir) 238 | tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False) 239 | elif model_name in [LLAMA2_13B]: 240 | model = AutoModel.from_pretrained(model_dir, use_safetensors=False) 241 | tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False) 242 | elif model_name in [CHATGLM2_6B, MOSS_7B]: 243 | model = AutoModel.from_pretrained(model_dir, trust_remote_code=True) 244 | if model_dir in [ 245 | "pretrain-guhao/merged_chatglm2_2", 246 | "pretrain-guhao/merged_chatglm2_3", 247 | ]: 248 | tokenizer = AutoTokenizer.from_pretrained( 249 | "pretrain-guhao/merged_chatglm2", use_fast=False, trust_remote_code=True 250 | ) 251 | else: 252 | tokenizer = AutoTokenizer.from_pretrained( 253 | model_dir, use_fast=False, trust_remote_code=True 254 | ) 255 | elif model_name in [BAICHUAN_7B, BAICHUAN_13B, BAICHUAN2_7B, BAICHUAN2_13B]: 256 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True) 257 | tokenizer = AutoTokenizer.from_pretrained( 258 | model_dir, use_fast=False, trust_remote_code=True 259 | ) 260 | elif model_name in [STABLEML_7B]: 261 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True) 262 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) 263 | else: 264 | model = AutoModel.from_pretrained(model_dir) 265 | tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False) 266 | 267 | # 有 gpu 并且是 LLM,才会增加 half process 268 | if gpu != -1 and model_name in [ 269 | LLAMA_7B, 270 | LLAMA_13B, 271 | LLAMA2_7B, 272 | LLAMA2_13B, 273 | VICUNA_7B, 274 | VICUNA_13B, 275 | ALPACE_13B, 276 | OPT_13B, 277 | BLOOM_7B, 278 | CHATGLM2_6B, 279 | MOSS_7B, 280 | BAICHUAN_7B, 281 | FALCON_7B, 282 | BAICHUAN_13B, 283 | STABLEML_7B, 284 | BAICHUAN2_7B, 285 | BAICHUAN2_13B, 286 | ]: 287 | model = model.half() 288 | 289 | # 有 gpu 才会放在cuda上 290 | if gpu != -1: 291 | torch.cuda.set_device(gpu) 292 | model.cuda() 293 | model.eval() 294 | 295 | load_time = time.time() 296 | 297 | print("Calculate embeddings...") 298 | start, end = find_start_end_pos(tokenizer) # only preserve [start:end+1] tokens 299 | batch_pos, feature_dim = find_batchpos_embdim( 300 | tokenizer, model, gpu 301 | ) # find batch pos 302 | 303 | df = pd.read_csv(trans_dir) 304 | for idx, row in df.iterrows(): 305 | name = row["name"] 306 | # -------------------------------------------------- 307 | if language == "chinese": 308 | sentence = 
row["chinese"] # process on Chinese 309 | elif language == "english": 310 | sentence = row["english"] 311 | # -------------------------------------------------- 312 | print(f"Processing {name} ({idx}/{len(df)})...") 313 | 314 | # extract embedding from sentences 315 | embeddings = [] 316 | if pd.isna(sentence) == False and len(sentence) > 0: 317 | inputs = tokenizer(sentence, return_tensors="pt") 318 | if gpu != -1: 319 | inputs = inputs.to("cuda") 320 | with torch.no_grad(): 321 | outputs = model( 322 | **inputs, output_hidden_states=True 323 | ).hidden_states # for new version 4.5.1 324 | outputs = torch.stack(outputs)[layer_ids].sum( 325 | dim=0 326 | ) # sum => [batch, T, D=768] 327 | outputs = outputs.cpu().numpy() # (B, T, D) 328 | if batch_pos == 0: 329 | embeddings = outputs[0, start:end] 330 | elif batch_pos == 1: 331 | embeddings = outputs[start:end, 0] 332 | 333 | # align with label timestamp and write csv file 334 | print(f"feature dimension: {feature_dim}") 335 | csv_file = os.path.join(save_dir, f"{name}.npy") 336 | if feature_level == "FRAME": 337 | embeddings = np.array(embeddings).squeeze() 338 | if len(embeddings) == 0: 339 | embeddings = np.zeros((1, feature_dim)) 340 | elif len(embeddings.shape) == 1: 341 | embeddings = embeddings[np.newaxis, :] 342 | np.save(csv_file, embeddings) 343 | else: 344 | embeddings = np.array(embeddings).squeeze() 345 | if len(embeddings) == 0: 346 | embeddings = np.zeros((feature_dim,)) 347 | elif len(embeddings.shape) == 2: 348 | embeddings = np.mean(embeddings, axis=0) 349 | np.save(csv_file, embeddings) 350 | 351 | end_time = time.time() 352 | print(f"Model load time used: {load_time - start_time:.1f}s.") 353 | print( 354 | f"Total {len(df)} files done! Time used ({model_name}): {end_time - start_time:.1f}s." 355 | ) 356 | 357 | 358 | if __name__ == "__main__": 359 | 360 | parser = argparse.ArgumentParser(description="Run.") 361 | parser.add_argument( 362 | "--dataset", 363 | type=str, 364 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/", 365 | help="input dataset", 366 | ) 367 | parser.add_argument("--gpu", type=int, default="0", help="gpu id") 368 | parser.add_argument( 369 | "--model_name", type=str, default="bloom-7b1", help="name of pretrained model" 370 | ) 371 | parser.add_argument( 372 | "--feature_level", 373 | type=str, 374 | default="UTTERANCE", 375 | choices=["UTTERANCE", "FRAME"], 376 | help="output types", 377 | ) 378 | # ------ 临时测试标点符号对于结果的影响 ------ 379 | parser.add_argument( 380 | "--punc_case", 381 | type=str, 382 | default=None, 383 | help="test punc impact to the performance", 384 | ) 385 | # ------ 临时Language对于结果的影响 ------ 386 | parser.add_argument("--language", type=str, default="chinese", help="used language") 387 | # ------ 临时测试外部接受的 model_dir [for gu hao] ------ 388 | parser.add_argument( 389 | "--model_dir", type=str, default=None, help="used user-defined model_dir" 390 | ) 391 | args = parser.parse_args() 392 | 393 | # (trans_dir, save_dir) 394 | if args.punc_case is None: 395 | trans_dir = os.path.join(args.dataset, "text", "transcription.csv") 396 | 397 | save_dir = os.path.join(args.dataset, "features") 398 | 399 | extract_embedding( 400 | model_name=args.model_name, 401 | trans_dir=trans_dir, 402 | save_dir=save_dir, 403 | feature_level=args.feature_level, 404 | gpu=args.gpu, 405 | punc_case=args.punc_case, 406 | language=args.language, 407 | model_dir=args.model_dir, 408 | ) 409 | --------------------------------------------------------------------------------
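For readers who want to trace the pooling performed inside extract_embedding(), the following is a minimal, self-contained sketch (not part of the repository) of the same steps on a single sentence: sum the last four hidden layers, drop the special tokens that find_start_end_pos would locate, and mean-pool the remaining tokens into an UTTERANCE-level vector. The checkpoint path is a placeholder; any HuggingFace encoder that returns hidden states behaves the same way, and a locally cached copy under config.PATH_TO_PRETRAINED_MODELS can be used instead.

import torch
from transformers import AutoModel, AutoTokenizer

model_dir = "hfl/chinese-roberta-wwm-ext"  # placeholder; swap in a locally cached checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
model = AutoModel.from_pretrained(model_dir).eval()

sentence = "今天天气真好"
inputs = tokenizer(sentence, return_tensors="pt")
with torch.no_grad():
    hidden_states = model(**inputs, output_hidden_states=True).hidden_states

# sum of the last four layers, as in extract_embedding (layer_ids = [-4, -3, -2, -1])
token_feats = torch.stack(hidden_states)[[-4, -3, -2, -1]].sum(dim=0)[0]  # [T, D]

# drop special tokens; for a BERT-style tokenizer find_start_end_pos resolves to start=1, end=-1
token_feats = token_feats[1:-1]

# UTTERANCE-level feature: mean over tokens, saved by the script as a 1-D .npy vector
utt_feature = token_feats.mean(dim=0).cpu().numpy()
print(utt_feature.shape)  # e.g. (768,) for a base-size encoder

FRAME-level extraction keeps token_feats as a [T, D] matrix instead of averaging, which matches the two save branches at the end of extract_embedding().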