├── .gitignore
├── images
│   └── fig1.png
├── models
│   ├── model-tune.yaml
│   ├── __init__.py
│   ├── modules
│   │   └── encoder.py
│   ├── auto_attention.py
│   └── emo_net.py
├── preprocess
│   ├── make_lst.sh
│   ├── check_feat_dim.py
│   ├── check_single_feat_dim.py
│   ├── mer2025_base.py
│   └── mer2024_base.py
├── toolkit
│   └── utils
│       ├── read_files.py
│       ├── eval.py
│       ├── draw_process.py
│       ├── functions.py
│       ├── metric.py
│       └── loss.py
├── feature_extraction
│   ├── run.sh
│   ├── config.py
│   ├── audio
│   │   ├── split_audio.py
│   │   └── extract_audio_huggingface.py
│   ├── visual
│   │   ├── split_video.py
│   │   ├── util.py
│   │   ├── extract_openface.py
│   │   └── extract_vision_huggingface.py
│   └── text
│       ├── split_asr.py
│       └── extract_text_huggingface.py
├── run_eval_unimodal.sh
├── eval_avg_performance.py
├── environment.yml
├── prompt.txt
├── dataloader
│   └── dataloader.py
├── README.md
├── config.py
├── test.py
├── LICENSE
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | saved/
2 | __pycache__/
3 | .DS_Store
--------------------------------------------------------------------------------
/images/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuyjan/MER2025-MRAC25/HEAD/images/fig1.png
--------------------------------------------------------------------------------
/models/model-tune.yaml:
--------------------------------------------------------------------------------
1 | auto_attention:
2 | train_input_mode: input
3 | feat_type: utt
4 | hidden_dim: 256
5 | dropout: 0.2
6 | grad_clip: -1.0
7 |
--------------------------------------------------------------------------------
/preprocess/make_lst.sh:
--------------------------------------------------------------------------------
1 | python mer2025_base.py --seed 0
2 | python mer2025_base.py --seed 1
3 | python mer2025_base.py --seed 2
4 | python mer2025_base.py --seed 3
5 | python mer2025_base.py --seed 4
--------------------------------------------------------------------------------
/preprocess/check_feat_dim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from glob import glob
4 |
5 | import numpy as np
6 |
7 | if __name__ == "__main__":
8 | feat_root = "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/"
9 | flst = glob(osp.join(feat_root, "*"))
10 |
11 | for feat_dir in flst:
12 | feat_name = osp.basename(feat_dir)
13 | fpath = glob(osp.join(feat_dir, "*"))[0]
14 | feat_dim = np.load(fpath).shape[0]
15 |
16 | print("{}, dim={}".format(feat_name, feat_dim))
17 |
--------------------------------------------------------------------------------
/toolkit/utils/read_files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import random
5 | import numpy as np
6 | import pandas as pd
7 |
8 |
9 | # Function 3: read the values of a given key from a csv file
10 | def func_read_key_from_csv(csv_path, key):
11 | values = []
12 | df = pd.read_csv(csv_path)
13 | for _, row in df.iterrows():
14 | if key not in row:
15 | values.append("")
16 | else:
17 | value = row[key]
18 | if pd.isna(value):
19 | value = ""
20 | values.append(value)
21 | return values
22 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .auto_attention import Auto_Attention
4 |
5 |
6 | class get_models(nn.Module):
7 | def __init__(self, args):
8 | super(get_models, self).__init__()
9 |         # misa/mmim may suffer from exploding gradients under some hyper-parameter settings
10 |         # tfn has a relatively high GPU memory footprint
11 |
12 | MODEL_MAP = {
13 |             # features are compressed to the utterance level before processing, so utt/align/unalign are all supported
14 | "auto_attention": Auto_Attention,
15 | }
16 | self.model = MODEL_MAP[args.model](args)
17 |
18 | def forward(self, batch):
19 | return self.model(batch)
20 |
--------------------------------------------------------------------------------
/toolkit/utils/eval.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from sklearn.metrics import mean_squared_error
4 | from sklearn.metrics import f1_score, accuracy_score
5 |
6 |
7 | def calculate_results(emo_probs=[], emo_labels=[], val_preds=[], val_labels=[]):
8 |
9 | emo_preds = np.argmax(emo_probs, 1)
10 | emo_accuracy = accuracy_score(emo_labels, emo_preds)
11 | emo_fscore = f1_score(emo_labels, emo_preds, average="weighted")
12 |
13 | results = {
14 | "emoprobs": emo_probs,
15 | "emolabels": emo_labels,
16 | "emoacc": emo_accuracy,
17 | "emofscore": emo_fscore,
18 | }
19 | outputs = f"f1:{emo_fscore:.4f}_acc:{emo_accuracy:.4f}"
20 | return results, outputs
21 |
--------------------------------------------------------------------------------
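A minimal usage sketch for the `calculate_results` helper above, with dummy data (the array values and the import path from the repository root are assumptions): `emo_probs` is expected as `[num_samples, num_classes]` and `emo_labels` as `[num_samples]`.

```python
# Toy check of calculate_results; the values are made up, not from the dataset.
import numpy as np

from toolkit.utils.eval import calculate_results

# 4 samples, 6 emotion classes (matching the "MER" rule in config.py)
emo_probs = np.array([
    [0.70, 0.10, 0.10, 0.05, 0.03, 0.02],
    [0.10, 0.60, 0.10, 0.10, 0.05, 0.05],
    [0.20, 0.10, 0.50, 0.10, 0.05, 0.05],
    [0.10, 0.10, 0.10, 0.60, 0.05, 0.05],
])
emo_labels = np.array([0, 1, 2, 3])

results, summary = calculate_results(emo_probs=emo_probs, emo_labels=emo_labels)
print(summary)  # "f1:1.0000_acc:1.0000" for this toy input
```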
/preprocess/check_single_feat_dim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from glob import glob
4 |
5 | import numpy as np
6 |
7 | if __name__ == "__main__":
8 | feat_root = "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/InternVL_2_5_HiCo_R16-UTT"
9 | npy_files = glob(osp.join(feat_root, "*.npy"))
10 |
11 | print(f"Found {len(npy_files)} npy files in {feat_root}")
12 | print("-" * 50)
13 |
14 | for npy_file in npy_files:
15 | file_name = osp.basename(npy_file)
16 | try:
17 | data = np.load(npy_file)
18 | print(f"{file_name}: shape={data.shape}")
19 | except Exception as e:
20 | print(f"{file_name}: Error loading file - {e}")
21 |
--------------------------------------------------------------------------------
/toolkit/utils/draw_process.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import os
3 | import os.path as osp
4 |
5 |
6 | def draw_loss(epoch, train_loss, valid_loss, save_path):
7 | plt.figure()
8 | plt.plot(epoch, train_loss, label="train loss")
9 | plt.plot(epoch, valid_loss, label="valid loss")
10 | plt.legend()
11 | plt.savefig(
12 | save_path,
13 | bbox_inches="tight",
14 | dpi=500,
15 | )
16 | plt.close()
17 | return
18 |
19 |
20 | def draw_metric(epoch, metric, key, save_path):
21 | plt.figure()
22 | plt.plot(epoch, metric, label="valid {}".format(key))
23 | plt.legend()
24 | plt.savefig(
25 | save_path,
26 | bbox_inches="tight",
27 | dpi=500,
28 | )
29 | plt.close()
30 |
--------------------------------------------------------------------------------
/toolkit/utils/functions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | # config -> args [only assign keys that exist in config but are missing or None in args]
5 | def merge_args_config(args, config):
6 | args_dic = vars(args) # convert to map version
7 | for key in config:
8 | if key not in args_dic or args_dic[key] is None:
9 | args_dic[key] = config[key]
10 | args_new = argparse.Namespace(**args_dic) # change to namespace
11 | return args_new
12 |
13 |
14 | def func_update_storage(inputs, prefix, outputs):
15 | for key in inputs:
16 | val = inputs[key]
17 | # update key and value
18 | newkey = f"{prefix}_{key}"
19 | newval = val
20 | # store into outputs
21 | assert newkey not in outputs
22 | outputs[newkey] = newval
23 |
--------------------------------------------------------------------------------
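How `merge_args_config` is wired into `train.py` is not shown in this snapshot, so the following is only a sketch of the intended pattern: command-line options left as `None` get filled from the per-model block of `models/model-tune.yaml` (indexing the yaml by `args.model` is an assumption).

```python
# Sketch: merge CLI args with the auto_attention block of models/model-tune.yaml.
import argparse

import yaml

from toolkit.utils.functions import merge_args_config

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="auto_attention")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--dropout", type=float, default=None)  # left unset on purpose
args = parser.parse_args([])

with open("models/model-tune.yaml") as f:
    config = yaml.safe_load(f)[args.model]  # {'train_input_mode': 'input', 'feat_type': 'utt', ...}

args = merge_args_config(args, config)
print(args.feat_type, args.hidden_dim, args.dropout)  # utt 256 0.2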
/toolkit/utils/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sklearn.metrics import mean_squared_error
4 | from sklearn.metrics import f1_score, accuracy_score
5 |
6 |
7 | # combined metric over the discrete (emotion) and dimensional (valence) predictions
8 | def overall_metric(emo_fscore, val_mse):
9 | final_score = emo_fscore - val_mse * 0.25
10 | return final_score
11 |
12 |
13 | # return only the metric value, used for model selection
14 | def gain_metric_from_results(eval_results, metric_name="emoval"):
15 |
16 | if metric_name == "emoval":
17 | fscore = eval_results["emofscore"]
18 | valmse = eval_results["valmse"]
19 | overall = overall_metric(fscore, valmse)
20 | sort_metric = overall
21 | elif metric_name == "emo":
22 | fscore = eval_results["emofscore"]
23 | sort_metric = fscore
24 | elif metric_name == "val":
25 | valmse = eval_results["valmse"]
26 | sort_metric = -valmse
27 | elif metric_name == "loss":
28 | loss = eval_results["loss"]
29 | sort_metric = -loss
30 |     else:
31 |         raise ValueError("unknown metric_name: {}".format(metric_name))
32 |
33 |     return sort_metric
34 |
--------------------------------------------------------------------------------
/feature_extraction/run.sh:
--------------------------------------------------------------------------------
1 |
2 | project_dir="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/"
3 |
4 | # --------------- visual preprocessing
5 | # --- extract faces
6 | # cd ./visual
7 | # python extract_openface.py --dataset=$project_dir --type="videoOne"
8 | # cd ../
9 | # #
10 | # # --- extract visual features
11 | # cd ./visual
12 | # python -u extract_vision_huggingface.py --dataset=$project_dir --feature_level='UTTERANCE' --model_name='clip-vit-large-patch14' --gpu=0
13 | # cd ../
14 |
15 | # # --------------- audio preprocessing
16 | # # --- extract audio from the videos
17 | # cd ./audio
18 | # python split_audio.py --dataset=$project_dir
19 | # cd ../
20 | # #
21 | # # --- extract audio features
22 | # cd ./audio
23 | # python -u extract_audio_huggingface.py --dataset=$project_dir --feature_level='UTTERANCE' --model_name='chinese-hubert-large' --gpu=0
24 | # cd ../
25 |
26 | # # --------------- text preprocessing
27 | # #--- obtain transcripts via ASR
28 | # cd ./text
29 | # python split_asr.py --dataset=$project_dir
30 | # cd ../
31 | # #
32 | # # --- extract text features
33 | cd ./text
34 | python extract_text_huggingface.py --dataset=$project_dir --feature_level='UTTERANCE' --model_name='bloom-7b1' --gpu=0
35 | cd ../
36 |
37 |
--------------------------------------------------------------------------------
/run_eval_unimodal.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | # define the value ranges for feat and seed
4 | feats=("InternVL_2_5_HiCo_R16-UTT")
5 | seeds=(0 1 2 3 4)
6 |
7 | # outer loop over feats
8 | for feat in "${feats[@]}"; do
9 |     # inner loop over seeds
10 | for seed in "${seeds[@]}"; do
11 | echo "Running with seed=$seed, feat=$feat"
12 |
13 |         # run the command
14 | python train.py --seed $seed \
15 | --dataset "seed${seed}" \
16 | --emo_rule "MER" \
17 | --save_model \
18 | --save_root "./saved/${feat}_3" \
19 | --feat "['${feat}', '${feat}', '${feat}']" \
20 | --lr 0.0001 \
21 | --batch_size 512 \
22 | --num_workers 4 \
23 |             --epochs 80
24 |
25 |         # check whether the command succeeded
26 | if [ $? -eq 0 ]; then
27 | echo "Command executed successfully for seed=$seed, feat=$feat"
28 | else
29 | echo "Error executing command for seed=$seed, feat=$feat"
30 | exit 1
31 | fi
32 | done
33 | done
--------------------------------------------------------------------------------
/feature_extraction/config.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 | import os
3 | import sys
4 | import socket
5 | import os.path as osp
6 |
7 | ############ global path ##############
8 | # PATH_TO_PROJECT = (
9 | # "/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/"
10 | # )
11 | # PATH_TO_VIDEO = osp.join(PATH_TO_PROJECT, "video")  # source folder of the raw videos
12 |
13 | # PATH_TO_FEATURES = osp.join(PATH_TO_PROJECT, "features")  # save folder for the extracted features of the clips
14 |
15 |
16 | # PATH_TO_RAW_FACE_Win = PATH_TO_VIDEO
17 | # PATH_TO_FEATURES_Win = PATH_TO_FEATURES
18 |
19 | ############ Models ##############
20 |
21 | # pre-trained models, including supervised and unsupervised
22 | PATH_TO_PRETRAINED_MODELS = "/sda/xyy/mer/MERTools/MER2024/tools"
23 | PATH_TO_OPENSMILE = "/sda/xyy/mer/MERTools/MER2024/tools/opensmile-2.3.0/"
24 | PATH_TO_FFMPEG = "/sda/xyy/mer/MERTools/MER2024/tools/ffmpeg-4.4.1-i686-static/ffmpeg"
25 | PATH_TO_WENET = (
26 | "/sda/xyy/mer/MERTools/MER2024/tools/wenet/wenetspeech_u2pp_conformer_libtorch"
27 | )
28 |
29 |
30 | PATH_TO_OPENFACE_Win = "/sda/xyy/mer/tools/OpenFace_2.2.0"
31 |
32 | PATH_TO_FFMPEG_Win = (
33 | "/sda/xyy/mer/MERTools/MER2024/tools/ffmpeg-3.4.1-win32-static/bin/ffmpeg"
34 | )
35 |
--------------------------------------------------------------------------------
/eval_avg_performance.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | if __name__ == "__main__":
7 | seeds = [0, 1, 2, 3, 4]
8 | feats = [
9 | "InternVL_2_5_HiCo_R16-UTT",
10 | ]
11 | # feats = [
12 | # "chinese-hubert-large-UTT",
13 | # "Qwen2-Audio-7B-UTT",
14 | # "chinese-hubert-base-UTT",
15 | # "whisper-large-v2-UTT",
16 | # "chinese-wav2vec2-large-UTT",
17 | # "chinese-wav2vec2-base-UTT",
18 | # "wavlm-base-UTT",
19 | # ]
20 | # feats = [
21 | # "chinese-roberta-wwm-ext-large-UTT",
22 | # "chinese-roberta-wwm-ext-UTT",
23 | # "chinese-macbert-large-UTT",
24 | # "chinese-macbert-base-UTT",
25 | # "bloom-7b1-UTT",
26 | # ]
27 |
28 | for feat in feats:
29 | avg_acc = 0
30 | avg_fscore = 0
31 | cnt = 0
32 | for seed in seeds:
33 | fdir = osp.join(
34 | "./saved", feat + "_3", "seed" + str(seed), "auto_attention", "run"
35 | )
36 |
37 | res = np.load(
38 | osp.join(fdir, "best_valid_results.npy"), allow_pickle=True
39 | ).item()
40 |
41 | avg_acc += res["eval_emoacc"]
42 | avg_fscore += res["eval_emofscore"]
43 | cnt += 1
44 |
45 | avg_acc /= cnt
46 | avg_fscore /= cnt
47 | print("### {}".format(feat + "_3"))
48 | print("acc={:.4f}, fscore={:.4f}".format(avg_acc, avg_fscore))
49 | print("")
50 |
--------------------------------------------------------------------------------
/feature_extraction/audio/split_audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import glob
4 | import sys
5 |
6 | current_file_path = os.path.abspath(__file__)
7 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path)))
8 | import config
9 | import argparse
10 |
11 |
12 | def func_split_audio_from_video_16k(video_root, save_root):
13 | if not os.path.exists(save_root):
14 | os.makedirs(save_root)
15 | for video_path in tqdm.tqdm(glob.glob(video_root + "/*")):
16 | videoname = os.path.basename(video_path)[:-4]
17 | audio_path = os.path.join(save_root, videoname + ".wav")
18 | if os.path.exists(audio_path):
19 | continue
20 | cmd = "%s -loglevel quiet -y -i %s -ar 16000 -ac 1 %s" % (
21 | config.PATH_TO_FFMPEG,
22 | video_path,
23 | audio_path,
24 | ) # linux
25 | # cmd = "%s -loglevel quiet -y -i %s -ar 16000 -ac 1 %s" %(config.PATH_TO_FFMPEG_Win, video_path, audio_path) # windows
26 | os.system(cmd)
27 |
28 |
29 | if __name__ == "__main__":
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument(
32 | "--dataset",
33 | type=str,
34 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/",
35 | help="file name",
36 | )
37 | args = parser.parse_args()
38 |
39 | dataset = args.dataset
40 |
41 | video_root = os.path.join(dataset, "video")
42 | save_root = os.path.join(dataset, "audio")
43 | func_split_audio_from_video_16k(video_root, save_root)
44 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: mertools
2 | channels:
3 | - pytorch
4 | - defaults
5 | - anaconda
6 | dependencies:
7 | - python=3.9
8 | - cudatoolkit
9 | - pip
10 | - pytorch=1.12.1
11 | - pytorch-mutex=1.0=cuda
12 | - torchaudio=0.12.1
13 | - torchvision=0.13.1
14 |
15 | - pip:
16 | - accelerate==0.16.0
17 | - aiohttp==3.8.4
18 | - aiosignal==1.3.1
19 | - async-timeout==4.0.2
20 | - attrs==22.2.0
21 | - bitsandbytes==0.37.0
22 | - cchardet==2.1.7
23 | - chardet==5.1.0
24 | - contourpy==1.0.7
25 | - cycler==0.11.0
26 | - filelock==3.9.0
27 | - fonttools==4.38.0
28 | - frozenlist==1.3.3
29 | - huggingface-hub==0.13.4
30 | - importlib-resources==5.12.0
31 | - kiwisolver==1.4.4
32 | - matplotlib==3.7.0
33 | - multidict==6.0.4
34 | - openai==0.27.0
35 | - packaging==23.0
36 | - psutil==5.9.4
37 | - pycocotools==2.0.6
38 | - pyparsing==3.0.9
39 | - python-dateutil==2.8.2
40 | - pyyaml==6.0
41 | - regex==2022.10.31
42 | - tokenizers==0.13.2
43 | - tqdm==4.64.1
44 | - transformers==4.28.0
45 | - timm==0.6.13
46 | - spacy==3.5.1
47 | - webdataset==0.2.48
48 | - scikit-learn==1.2.2
49 | - scipy==1.10.1
50 | - yarl==1.8.2
51 | - zipp==3.14.0
52 | - omegaconf==2.3.0
53 | - opencv-python==4.7.0.72
54 | - iopath==0.1.10
55 | - decord==0.6.0
56 | - tenacity==8.2.2
57 | - peft
58 | - pycocoevalcap
59 | - sentence-transformers
60 | - umap-learn
61 | - notebook
62 | - gradio==3.24.1
63 | - gradio-client==0.0.8
64 | - wandb
65 | - einops
66 | - SentencePiece
67 | - ftfy
68 | - thop
69 | - pytorchvideo==0.1.5
--------------------------------------------------------------------------------
/feature_extraction/visual/split_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import cv2
4 | from moviepy.editor import VideoFileClip
5 | import argparse
6 | import sys
7 |
8 | current_file_path = os.path.abspath(__file__)
9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))))
10 | import config
11 |
12 | if __name__ == "__main__":
13 |     split_t = config.VIDEO_SPLIT_T  # interval between consecutive split start times
14 |     duration_t = config.VIDEO_DURATION_T  # duration of each split clip
15 |
16 |     # ----------- paths that must be specified
17 |     video_dir = config.PATH_TO_VIDEO  # source folder of the raw videos
18 |     save_basedir = config.PATH_TO_VIDEO_CLIP  # save folder for the clipped segments
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument(
22 | "--dataset",
23 | type=str,
24 | default="test",
25 | help="file name",
26 | )
27 | args = parser.parse_args()
28 |
29 | ftitle = args.dataset
30 |
31 |     # path to the video file
32 | video_path = osp.join(video_dir, ftitle + ".mp4")
33 |
34 |     # save path
35 | output_folder = osp.join(save_basedir, ftitle)
36 | if not osp.exists(output_folder):
37 | os.makedirs(output_folder)
38 |
39 |     # load the video
40 | clip = VideoFileClip(video_path)
41 |
42 |     # total video duration
43 | duration = clip.duration
44 |
45 |     # start time of the current split
46 | start_time = 0
47 |
48 |     # split the video in a loop until the end is reached
49 | cnt = 0
50 | for i in range(0, int(duration // split_t)):
51 |         # compute the end time, making sure it does not exceed the total duration
52 | end_time = min(start_time + duration_t, duration)
53 |
54 |         # create the sub-clip
55 | subclip = clip.subclip(start_time, end_time)
56 |
57 |         # build the output filename
58 | output_filename = os.path.join(output_folder, f"clip_{cnt}.mp4")
59 |
60 |         # save the sub-clip
61 | subclip.write_videofile(output_filename, codec="libx264")
62 |
63 |         # update the start time
64 | start_time += split_t
65 | print(f"{cnt} done!")
66 |
67 | cnt += 1
68 |
69 |     # release resources
70 | clip.close()
71 |
72 |     print("video splitting finished")
73 |
--------------------------------------------------------------------------------
/models/modules/encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Ref paper: Tensor Fusion Network for Multimodal Sentiment Analysis
3 | Ref url: https://github.com/Justin1904/TensorFusionNetworks
4 | """
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | ## both of these modules are used in TFN (video|audio)
9 | class MLPEncoder(nn.Module):
10 | '''
11 | The subnetwork that is used in TFN for video and audio in the pre-fusion stage
12 | '''
13 |
14 | def __init__(self, in_size, hidden_size, dropout):
15 | '''
16 | Args:
17 | in_size: input dimension
18 | hidden_size: hidden layer dimension
19 | dropout: dropout probability
20 | Output:
21 | (return value in forward) a tensor of shape (batch_size, hidden_size)
22 | '''
23 | super(MLPEncoder, self).__init__()
24 | # self.norm = nn.BatchNorm1d(in_size)
25 | self.drop = nn.Dropout(p=dropout)
26 | self.linear_1 = nn.Linear(in_size, hidden_size)
27 | self.linear_2 = nn.Linear(hidden_size, hidden_size)
28 | self.linear_3 = nn.Linear(hidden_size, hidden_size)
29 |
30 | def forward(self, x):
31 | '''
32 | Args:
33 | x: tensor of shape (batch_size, in_size)
34 | '''
35 | # normed = self.norm(x)
36 | dropped = self.drop(x)
37 | y_1 = F.relu(self.linear_1(dropped))
38 | y_2 = F.relu(self.linear_2(y_1))
39 | y_3 = F.relu(self.linear_3(y_2))
40 |
41 | return y_3
42 |
43 |
44 | # the text encoder in TFN, which adds an extra LSTM step [arguably better suited to audio|video]
45 | class LSTMEncoder(nn.Module):
46 | '''
47 | The LSTM-based subnetwork that is used in TFN for text
48 | '''
49 |
50 | def __init__(self, in_size, hidden_size, dropout, num_layers=1, bidirectional=False):
51 |
52 | super(LSTMEncoder, self).__init__()
53 |
54 | if num_layers == 1:
55 | rnn_dropout = 0.0
56 | else:
57 | rnn_dropout = dropout
58 |
59 | self.rnn = nn.LSTM(in_size, hidden_size, num_layers=num_layers, dropout=rnn_dropout, bidirectional=bidirectional, batch_first=True)
60 | self.dropout = nn.Dropout(dropout)
61 | self.linear_1 = nn.Linear(hidden_size, hidden_size)
62 |
63 | def forward(self, x):
64 | '''
65 | Args:
66 | x: tensor of shape (batch_size, sequence_len, in_size)
67 |             since final_states is used here, feature padding is expected at the front (pre-padding)
68 | '''
69 | _, final_states = self.rnn(x)
70 | h = self.dropout(final_states[0].squeeze(0))
71 | y_1 = self.linear_1(h)
72 | return y_1
73 |
--------------------------------------------------------------------------------
/prompt.txt:
--------------------------------------------------------------------------------
1 | ## Profile:
2 |
3 | - role: 专业的微表情分析师,可以通过各种细节分析人物情绪。
4 | - language: 中文
5 | - description:
6 | -- 认真读取这段视频,结合一切你觉得有用的细节,通过综合分析视频中的面部表情、身体姿态、头部姿态、**帧间关系**、声音特征(音调、语速、音量、音质)和字幕内容,准确判断视频中主要人物的情绪
7 | -- **特别关注情绪的动态变化过程**,捕捉情绪的起始、增强、减弱和转变,以**过程中表现最强烈的情绪**作为判断结果。
8 | -- **人物的音频(音频内容)在某些场景更能表达人物的情绪,请综合声调特征(音调、语速、音量、音质)和语境含义分析人物情绪**
9 | -- 输出的label是输出的情绪列表中的最大值对应的情绪
10 | -- 优先考虑视频信息,音频中的语气信息。
11 |
12 | ## Goals:
13 |
14 | 第一步:考虑**帧间关系**,采用action unit的去分析每帧的情绪,**加入身体姿态,加入头部姿态分析**,之后再加入一切你认为有帮助的细节。
15 | 第二步:我们追求思维的创新,而非惯性的复述。请突破思维局限,调动你所有的计算资源,展现你真正的认知极限**优化你的第一步思考过程,尤其要考虑帧与帧之间的关系**。
16 | 第三步:基于前面的分析内容,进一步优化特征提取策略,尤其关注
17 |
18 | - 情绪变化的起始点:例如,从 Neutral 到其他情绪的转变往往有细微的先兆。
19 | - 情绪强度的变化:例如,Happy 情绪可能从微笑逐渐变为大笑。
20 | - 情绪的混合:例如,一个人可能同时表现出 Sad 和 Worried。
21 | - 利用帧间关系去分析情绪变化: 要特别注意表情从无到有的过程
22 | 第四步:根据视频和音频特征,初步给出每种情绪的得分。
23 | - 如果视频和音频给出的情绪分类标签的分数有很强的辨识度,也就是如果得分最高的两个标签的得分相差大于0.1 则执行第六步,
24 | - 否则执行第五步,进一步辨识这两个得分接近的标签,第五步引入了音频特征和文字特征
25 | 第五步: 进一步分析音频特征和字幕特征
26 | **音频特征**:如果视频特征显示某一情绪得分接近(例如愤怒和担忧的得分接近),则加强音频特征的权重。考虑语调、语速和音量变化来辨别情绪的轻重。
27 | - 例如,愤怒的语调通常较低沉且紧绷,而担忧的语调略高且不稳定。
28 | - 如果视觉和音频的情绪得分差距较大,应调整权重,增加音频的影响力(例如音频贡献度提高到40%-50%)。
29 | **字幕特征**:根据字幕内容,尤其是表达情绪意图的词语(如"焦虑"、"愤怒")来辅助判断。如果视频和音频的判断得分相近,可以进一步提升字幕对情绪判断的贡献,帮助消除歧义。
30 | - 注意避免文字的误导性判断。例如,"你在干什么?" 可能是疑问;但是语气很愤怒时,也可以表示愤怒;表情开心时,可以解读为惊喜。
31 | 第六步:总结你的所有依据,给出最终结论,输出的推理过程应该保持简洁(极其重要)。
32 |
33 |
34 | ## Constraints:
35 |
36 | - 请用以下规范输出:
37 | - 情绪只能为["neutral", "angry", "happy", "sad", "worried", "surprise"]中的一个;
38 | - 对于视频中存在多个人物的情况,只需要分析出镜时间最长的主要人物的情绪;
39 | - 确保输出的json格式的正确性;
40 | - 语言平实、具体、简明、准确;
41 | - 优先选择具体名词替代抽象概念;
42 | - 保持段落简明(不超过5行);
43 | - 禁用文学化修辞;
44 | - 重点信息前置;
45 | - 复杂内容分点说明;
46 | - 保持口语化但不过度简化专业内容;
47 | - 确保信息准确前提下优先选择大众认知词汇;
48 | - **worried 的优先级相对较低**。仅在视频中有明显 Worried 表现且其他情绪不明显时预测为 Worried。
49 | - Sad/Surprise/Neutral 优先级:在这些情绪与 Worried 接近时,优先考虑 Sad/Surprise/Neutral。
50 |
51 | ## OutputFormat:
52 |
53 | {
54 | "推理过程": "主人公的面部表情和语气都很愉快,所以我认为他是开心的, 但是他的眼神有点忧郁,所以我认为他也有点伤心, 但是他的语气很平静,所以我认为他不是很生气",
55 | "情绪": {
56 | "neutral": 0.0,
57 | "angry": 0.0,
58 | "happy": 0.0,
59 | "sad": 0.0,
60 | "worried": 0.0,
61 | "surprise": 0.0
62 | },
63 | "label": "happy",
64 | "贡献度":{
65 | "视频": "0%",
66 | "音频": "0%",
67 | "字幕": "0%"
68 | }
69 |
70 | }
71 |
72 |
73 | ## Example:
74 |
75 | {
76 | "推理过程": "主人公的面部表情和语气都很愉快,所以我认为他是开心的, 但是他的眼神有点忧郁,所以我认为他也有点伤心, 但是他的语气很平静,所以我认为他不是很生气",
77 | "情绪": {
78 | "neutral": 0.3,
79 | "angry": 0.01,
80 | "happy": 0.8,
81 | "sad": 0.01,
82 | "worried": 0.1,
83 | "surprise": 0.4
84 | },
85 | "label": "happy",
86 | "贡献度":{
87 | "视频": "40%",
88 | "音频": "30%",
89 | "字幕": "30%"
90 | }
91 | }
--------------------------------------------------------------------------------
/dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Description: build [train / valid / test] dataloader
3 | Author: Xiongjun Guan
4 | Date: 2023-03-01 19:41:05
5 | version: 0.0.1
6 | LastEditors: Xiongjun Guan
7 | LastEditTime: 2023-03-20 21:01:47
8 |
9 | Copyright (C) 2023 by Xiongjun Guan, Tsinghua University. All rights reserved.
10 | """
11 |
12 | import copy
13 | import logging
14 | import os.path as osp
15 | import random
16 |
17 | import cv2
18 | import numpy as np
19 | from scipy import signal
20 | from torch.utils.data import DataLoader, Dataset
21 |
22 |
23 | class load_dataset_train(Dataset):
24 |
25 | def __init__(
26 | self,
27 | names,
28 | emo_labels,
29 | val_labels,
30 | emo_rule,
31 | feat_roots,
32 | ):
33 | self.names = names
34 | self.emo_labels = emo_labels
35 | self.val_labels = val_labels
36 | self.feat_roots = feat_roots
37 |
38 | self.emo2idx_mer, self.idx2emo_mer = {}, {}
39 | for ii, emo in enumerate(emo_rule):
40 | self.emo2idx_mer[emo] = ii
41 | for ii, emo in enumerate(emo_rule):
42 | self.idx2emo_mer[ii] = emo
43 |
44 | def __len__(self):
45 | return len(self.names)
46 |
47 | def __getitem__(self, idx):
48 | emo = self.emo2idx_mer[self.emo_labels[idx]]
49 | val = self.val_labels[idx]
50 | name = self.names[idx]
51 |
52 | inputs = {}
53 | idx = 0
54 | for feat_root in self.feat_roots:
55 | ftitle = osp.basename(name).split(".")[0]
56 | inputs["feat{}".format(idx)] = (
57 | np.load(osp.join(feat_root, ftitle + ".npy"))
58 | .squeeze()
59 | .astype(np.float32)
60 | )
61 | idx += 1
62 |
63 | return inputs, emo, val, name
64 |
65 |
66 | def get_dataloader_train(
67 | names,
68 | emo_labels,
69 | val_labels,
70 | emo_rule,
71 | feat_roots,
72 | batch_size=1,
73 | num_workers=4,
74 | shuffle=True,
75 | ):
76 | # Create dataset
77 | try:
78 | dataset = load_dataset_train(
79 | names=names,
80 | emo_labels=emo_labels,
81 | val_labels=val_labels,
82 | emo_rule=emo_rule,
83 | feat_roots=feat_roots,
84 | )
85 | except Exception as e:
86 |         logging.error("Error in DataLoader: %s", repr(e))
87 | return
88 |
89 | train_loader = DataLoader(
90 | dataset,
91 | batch_size=batch_size,
92 | shuffle=shuffle,
93 | num_workers=num_workers,
94 | pin_memory=True,
95 | )
96 | logging.info(f"n_train:{len(dataset)}")
97 |
98 | return train_loader
99 |
100 |
101 | def get_dataloader_valid(
102 | names,
103 | emo_labels,
104 | val_labels,
105 | emo_rule,
106 | feat_roots,
107 | batch_size=1,
108 | num_workers=4,
109 | shuffle=False,
110 | ):
111 | # Create dataset
112 | try:
113 | dataset = load_dataset_train(
114 | names=names,
115 | emo_labels=emo_labels,
116 | val_labels=val_labels,
117 | emo_rule=emo_rule,
118 | feat_roots=feat_roots,
119 | )
120 | except Exception as e:
121 |         logging.error("Error in DataLoader: %s", repr(e))
122 | return
123 |
124 | valid_loader = DataLoader(
125 | dataset,
126 | batch_size=batch_size,
127 | shuffle=shuffle,
128 | num_workers=num_workers,
129 | pin_memory=True,
130 | )
131 | logging.info(f"n_valid:{len(dataset)}")
132 |
133 | return valid_loader
134 |
--------------------------------------------------------------------------------
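A sketch of how the split files written by `./preprocess/mer2025_base.py` feed into `get_dataloader_train` above; the exact wiring lives in `train.py` (not shown in this section), and the chosen feature, split name, and batch size are assumptions.

```python
# Sketch: build a training dataloader from one split file (seed0) and one feature type.
import os.path as osp

import numpy as np

import config
from dataloader.dataloader import get_dataloader_train

split = np.load(
    osp.join(config.PATH_TO_TRAIN_LST, "seed0.npy"), allow_pickle=True
).item()
train_info = split["train_info"]

# one unimodal expert; the path comes from the feature dicts in config.py
feat_roots = [config.FEAT_VIDEO_DICT["InternVL_2_5_HiCo_R16-UTT"][1]]

train_loader = get_dataloader_train(
    names=train_info["names"],
    emo_labels=train_info["emos"],
    val_labels=train_info["vals"],
    emo_rule=config.EMO_RULE["MER"],
    feat_roots=feat_roots,
    batch_size=32,
    num_workers=4,
)

inputs, emo, val, name = next(iter(train_loader))
print(inputs["feat0"].shape, emo.shape)  # e.g. torch.Size([32, 4096]) torch.Size([32])
```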
/README.md:
--------------------------------------------------------------------------------
1 | # More Is Better: A MoE-Based Emotion Recognition Framework with Human Preference Alignment (MRAC @ ACM-MM 2025)
2 |
3 | [](https://opensource.org/licenses/Apache-2.0)
4 | [](https://arxiv.org/abs/2508.06036)
5 | [](https://github.com/zhuyjan/MER2025-MRAC25/stargazers)
6 |
7 | This repository provides our runner-up solution for [MER2025-SEMI challenge](https://zeroqiaoba.github.io/MER2025-website/) at MRAC'25 workshop. If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏
8 |
9 | ## Overview
10 | ![MER-SEMI](images/fig1.png)
11 |
12 | We propose a comprehensive framework, grounded in the principle that "more is better," to construct a robust Mixture of Experts (MoE) emotion recognition system. Our approach integrates a diverse range of input modalities as independent experts, including novel signals such as knowledge from large Vision-Language Models (VLMs) and temporal Action Unit (AU) information.
13 |
14 | ## Setup
15 | Our environment setup is identical to that of [MERBench](https://github.com/zeroQiaoba/MERTools/tree/master/MERBench).
16 | ```bash
17 | conda env create -f environment.yml
18 | ```
19 |
20 | ## Workflow
21 |
22 | **Preprocessing**
23 | 1. Extract features for each sample. Each sample should correspond to a `.npy` file with shape `(C,)`.
24 | You can refer to [MER2025 Track1](https://github.com/zeroQiaoba/MERTools/tree/master/MER2025/MER2025_Track1) for more details on feature extraction.
25 | (We also provide our Gemini prompt for reference — see `prompt.txt` for details.)
26 | 2. Run `./preprocess/check_feat_dim.py` to obtain the dimensionality of each type of feature.
27 | 3. Fill in the feature names and their corresponding dimensions in the modality-specific dictionaries in `./config.py`.
28 | 4. Run `./preprocess/mer2025_base.py` to split the dataset into training and validation sets.
29 |
30 | **Training**
31 | - Run `./train.py` to perform training and validation.
32 | - Run `./run_eval_unimodal.sh` to evaluate unimodal performance. (All three branches use the same input for fair comparison with the trimodal setting.)
33 |
34 | ## Training Information Format
35 |
36 | Saved as `.npy` files under `./train_lst/`, each containing a Python dictionary with the following structure:
37 |
38 | ```python
39 | train_info = {
40 | # List of video names, e.g., ["example1", ...]
41 | "names": split_train_names,
42 | # List of emotion labels, e.g., ["happy", ...]
43 | "emos": split_train_emos,
44 | # List of emotion valence values, e.g., [-10, ...]
45 | "vals": split_train_vals,
46 | }
47 |
48 | valid_info = {
49 | # List of video names, e.g., ["example1", ...]
50 | "names": split_valid_names,
51 | # List of emotion labels, e.g., ["happy", ...]
52 | "emos": split_valid_emos,
53 | # List of emotion valence values, e.g., [-10, ...]
54 | "vals": split_valid_vals,
55 | }
56 | ```
57 |
58 | ## Citation
59 | If you find this work useful for your research, please give us a star and use the following BibTeX entry for citation.
60 | ```
61 | @inproceedings{xie2025more,
62 |   title={More is better: A {MoE}-based emotion recognition framework with human preference alignment},
63 | author={Xie, Jun and Zhu, Yingjian and Chen, Feng and Zhang, Zhenghao and Fan, Xiaohui and Yi, Hongzhu and Wang, Xinming and Yu, Chen and Bi, Yue and Zhao, Zhaoran and others},
64 | booktitle={Proceedings of the 3rd International Workshop on Multimodal and Responsible Affective Computing},
65 | pages={2--7},
66 | year={2025}
67 | }
68 | ```
69 |
--------------------------------------------------------------------------------
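A condensed, end-to-end view of the Workflow section in the README above (a sketch, not a script shipped with the repo); the dataset and feature paths are hard-coded inside the preprocessing scripts and `config.py`, so they must be adapted before any of these commands will run.

```bash
# 1. check feature dimensionality, then fill the dicts in ./config.py by hand
python preprocess/check_feat_dim.py

# 2. build five train/valid splits (equivalent to ./preprocess/make_lst.sh)
for seed in 0 1 2 3 4; do
    python preprocess/mer2025_base.py --seed $seed
done

# 3. train/validate per seed, then average the per-seed results
bash run_eval_unimodal.sh
python eval_avg_performance.py
```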
/feature_extraction/visual/util.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 | import os
3 | import re
4 | import pandas as pd
5 | import numpy as np
6 | import struct
7 |
8 |
9 | ## for OPENFACE
10 | ## reference: https://gist.github.com/btlorch/6d259bfe6b753a7a88490c0607f07ff8
11 | def read_hog(filename, batch_size=5000):
12 | """
13 | Read HoG features file created by OpenFace.
14 | For each frame, OpenFace extracts 12 * 12 * 31 HoG features, i.e., num_features = 4464. These features are stored in row-major order.
15 | :param filename: path to .hog file created by OpenFace
16 | :param batch_size: how many rows to read at a time
17 | :return: is_valid, hog_features
18 | is_valid: ndarray of shape [num_frames]
19 | hog_features: ndarray of shape [num_frames, num_features]
20 | """
21 | all_feature_vectors = []
22 | with open(filename, "rb") as f:
23 | (num_cols,) = struct.unpack("i", f.read(4)) # 12
24 | (num_rows,) = struct.unpack("i", f.read(4)) # 12
25 | (num_channels,) = struct.unpack("i", f.read(4)) # 31
26 |
27 | # The first four bytes encode a boolean value whether the frame is valid
28 | num_features = 1 + num_rows * num_cols * num_channels
29 | feature_vector = struct.unpack(
30 | "{}f".format(num_features), f.read(num_features * 4)
31 | )
32 | feature_vector = np.array(feature_vector).reshape(
33 | (1, num_features)
34 | ) # [1, 4464+1]
35 | all_feature_vectors.append(feature_vector)
36 |
37 | # Every frame contains a header of four float values: num_cols, num_rows, num_channels, is_valid
38 | num_floats_per_feature_vector = 4 + num_rows * num_cols * num_channels
39 | # Read in batches of given batch_size
40 | num_floats_to_read = num_floats_per_feature_vector * batch_size
41 | # Multiply by 4 because of float32
42 | num_bytes_to_read = num_floats_to_read * 4
43 |
44 | while True:
45 | bytes = f.read(num_bytes_to_read)
46 | # For comparison how many bytes were actually read
47 | num_bytes_read = len(bytes)
48 | assert (
49 | num_bytes_read % 4 == 0
50 | ), "Number of bytes read does not match with float size"
51 | num_floats_read = num_bytes_read // 4
52 | assert (
53 | num_floats_read % num_floats_per_feature_vector == 0
54 | ), "Number of bytes read does not match with feature vector size"
55 | num_feature_vectors_read = num_floats_read // num_floats_per_feature_vector
56 |
57 | feature_vectors = struct.unpack("{}f".format(num_floats_read), bytes)
58 | # Convert to array
59 | feature_vectors = np.array(feature_vectors).reshape(
60 | (num_feature_vectors_read, num_floats_per_feature_vector)
61 | )
62 | # Discard the first three values in each row (num_cols, num_rows, num_channels)
63 | feature_vectors = feature_vectors[:, 3:]
64 | # Append to list of all feature vectors that have been read so far
65 | all_feature_vectors.append(feature_vectors)
66 |
67 | if num_bytes_read < num_bytes_to_read:
68 | break
69 |
70 | # Concatenate batches
71 | all_feature_vectors = np.concatenate(all_feature_vectors, axis=0)
72 |
73 | # Split into is-valid and feature vectors
74 | is_valid = all_feature_vectors[:, 0]
75 | feature_vectors = all_feature_vectors[:, 1:]
76 |
77 | return is_valid, feature_vectors
78 |
79 |
80 | ## for OPENFACE
81 | def read_csv(filename, startIdx):
82 | data = pd.read_csv(filename)
83 | all_feature_vectors = []
84 | for index in data.index:
85 | features = np.array(data.iloc[index][startIdx:])
86 | all_feature_vectors.append(features)
87 | all_feature_vectors = np.array(all_feature_vectors)
88 | return all_feature_vectors
89 |
--------------------------------------------------------------------------------
/preprocess/mer2025_base.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import shutil
4 | import sys
5 |
6 | import pandas as pd
7 |
8 | sys.path.append("/mnt/public/gxj_2/EmoNet_Pro/")
9 | import argparse
10 | import os.path as osp
11 | import random
12 | from collections import Counter
13 |
14 | import numpy as np
15 | import torch
16 | from tqdm import tqdm
17 |
18 | import config
19 |
20 |
21 | def set_seed(seed=0):
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | np.random.seed(seed)
26 | random.seed(seed)
27 | torch.backends.cudnn.benchmark = False
28 | torch.backends.cudnn.deterministic = True
29 |
30 |
31 | # run -d toolkit/preprocess/mer2024.py
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument(
35 | "--seed",
36 | type=int,
37 | default=0,
38 | help="random seed",
39 | )
40 | args = parser.parse_args()
41 |
42 | random_seed = args.seed
43 | set_seed(random_seed)
44 | print("### random seed: {}".format(random_seed))
45 |
46 | # --- data config
47 | emo_rule = "MER"
48 | gt_path = "/mnt/public/share/data/MER2025/mer2025-dataset/track1_train_disdim.csv"
49 |
50 | # --- save config
51 | save_root = "/mnt/public/gxj_2/EmoNet_Pro/lst_train/mer25_train_val"
52 | save_name = "seed{}".format(random_seed)
53 |
54 | # --- load data and format
55 | mapping_rules = config.EMO_RULE[emo_rule]
56 |
57 | all_names = []
58 | all_emos = []
59 | all_vals = []
60 | df = pd.read_csv(gt_path)
61 | for index, row in tqdm(df.iterrows(), total=len(df)):
62 | name = row["name"]
63 | discrete = row["discrete"]
64 | valence = row["valence"]
65 |
66 | # emo = mapping_rules.index(discrete)
67 |
68 | all_names.append(name)
69 | all_emos.append(discrete)
70 | all_vals.append(valence)
71 |
72 | counted_numbers = {}
73 | for emo in all_emos:
74 | if emo in counted_numbers:
75 | counted_numbers[emo] += 1
76 | else:
77 | counted_numbers[emo] = 1
78 | for emo in mapping_rules:
79 | print(
80 | "{}, num={}, percent={:.2f}%".format(
81 | emo,
82 | counted_numbers[emo],
83 | 100 * counted_numbers[emo] / len(all_emos),
84 | )
85 | )
86 |
87 | # --------------- split train & test ---------------
88 | split_ratio = 0.2
89 | whole_num = len(all_names)
90 |
91 | # gain indices for cross-validation
92 | indices = np.arange(whole_num)
93 | random.shuffle(indices)
94 |
95 | # split indices into 1-fold
96 | each_folder_num = int(whole_num * split_ratio)
97 | valid_idxs = indices[0:each_folder_num]
98 | train_idxs = indices[each_folder_num:]
99 |
100 | split_train_names = []
101 | split_train_emos = []
102 | split_train_vals = []
103 | for idx in train_idxs:
104 | split_train_names.append(all_names[idx])
105 | split_train_emos.append(all_emos[idx])
106 | split_train_vals.append(all_vals[idx])
107 |
108 | split_valid_names = []
109 | split_valid_emos = []
110 | split_valid_vals = []
111 | for idx in valid_idxs:
112 | split_valid_names.append(all_names[idx])
113 | split_valid_emos.append(all_emos[idx])
114 | split_valid_vals.append(all_vals[idx])
115 |
116 | train_info = {
117 | "names": split_train_names,
118 | "emos": split_train_emos,
119 | "vals": split_train_vals,
120 | }
121 | valid_info = {
122 | "names": split_valid_names,
123 | "emos": split_valid_emos,
124 | "vals": split_valid_vals,
125 | }
126 |
127 | print("----------- summary")
128 | print(
129 | "split_train: {}\nsplit_valid: {}".format(
130 | len(train_info["names"]), len(valid_info["names"])
131 | )
132 | )
133 |
134 | save_path = os.path.join(save_root, save_name + ".npy")
135 | if not os.path.exists(save_root):
136 | os.makedirs(save_root)
137 |
138 | np.save(
139 | save_path,
140 | {
141 | "train_info": train_info,
142 | "valid_info": valid_info,
143 | },
144 | )
145 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # ----------------- path to train/test info
2 | PATH_TO_TRAIN_LST = "/mnt/public/gxj_2/EmoNet_Pro/lst_train/mer25_train_val"
3 | PATH_TO_TEST_LST = "/mnt/public/gxj_2/EmoNet_Pro/lst_test/"
4 |
5 | # ----------------- emo / index matching rules
6 | EMO_RULE = {
7 | "MER": ["neutral", "angry", "happy", "sad", "worried", "surprise"],
8 | "CREMA-D": ["neutral", "angry", "happy", "sad", "fear", "disgust"],
9 | "TESS": ["neutral", "angry", "happy", "sad", "fear", "disgust"],
10 | "RAVDESS": [
11 | "neutral",
12 | "angry",
13 | "happy",
14 | "sad",
15 | "fear",
16 | "disgust",
17 | "surprised",
18 | "calm",
19 | ],
20 | }
21 |
22 | # ----------------- features can be used
23 | FEAT_VIDEO_DICT = {
24 | "senet50face_UTT": [
25 | 512,
26 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/senet50face_UTT",
27 | ],
28 | "resnet50face_UTT": [
29 | 512,
30 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/resnet50face_UTT",
31 | ],
32 | "clip-vit-large-patch14-UTT": [
33 | 768,
34 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/clip-vit-large-patch14-UTT",
35 | ],
36 | "clip-vit-base-patch32-UTT": [
37 | 512,
38 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/clip-vit-base-patch32-UTT",
39 | ],
40 | "videomae-large-UTT": [
41 | 1024,
42 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/videomae-large-UTT",
43 | ],
44 | "videomae-base-UTT": [
45 | 768,
46 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/videomae-base-UTT",
47 | ],
48 | "manet_UTT": [
49 | 1024,
50 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/manet_UTT",
51 | ],
52 | "emonet_UTT": [
53 | 256,
54 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/emonet_UTT",
55 | ],
56 | "dinov2-large-UTT": [
57 | 1024,
58 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/dinov2-large-UTT",
59 | ],
60 | "InternVL_2_5_HiCo_R16-UTT": [
61 | 4096,
62 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/InternVL_2_5_HiCo_R16-UTT",
63 | ],
64 | }
65 |
66 | FEAT_AUDIO_DICT = {
67 | "chinese-hubert-large-UTT": [
68 | 1024,
69 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-hubert-large-UTT",
70 | ],
71 | "Qwen2-Audio-7B-UTT": [
72 | 1280,
73 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/Qwen2-Audio-7B-UTT",
74 | ],
75 | "chinese-hubert-base-UTT": [
76 | 768,
77 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-hubert-base-UTT",
78 | ],
79 | "whisper-large-v2-UTT": [
80 | 1280,
81 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/whisper-large-v2-UTT",
82 | ],
83 | "chinese-wav2vec2-large-UTT": [
84 | 1024,
85 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-wav2vec2-large-UTT",
86 | ],
87 | "chinese-wav2vec2-base-UTT": [
88 | 768,
89 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-wav2vec2-base-UTT",
90 | ],
91 | "wavlm-base-UTT": [
92 | 768,
93 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/wavlm-base-UTT",
94 | ],
95 | }
96 |
97 | FEAT_TEXT_DICT = {
98 | "chinese-roberta-wwm-ext-large-UTT": [
99 | 1024,
100 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-roberta-wwm-ext-large-UTT",
101 | ],
102 | "chinese-roberta-wwm-ext-UTT": [
103 | 768,
104 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-roberta-wwm-ext-UTT",
105 | ],
106 | "chinese-macbert-large-UTT": [
107 | 1024,
108 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-macbert-large-UTT",
109 | ],
110 | "chinese-macbert-base-UTT": [
111 | 768,
112 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/chinese-macbert-base-UTT",
113 | ],
114 | "bloom-7b1-UTT": [
115 | 4096,
116 | "/mnt/public/share/data/MER2025/mer2025-dataset-process/features/bloom-7b1-UTT",
117 | ],
118 | }
119 |
120 | MODEL_DIR_DICT = {}
121 |
--------------------------------------------------------------------------------
/preprocess/mer2024_base.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import shutil
4 | import sys
5 |
6 | sys.path.append("/mnt/public/gxj_2/EmoNet_Pro/")
7 | import argparse
8 | import os.path as osp
9 | import random
10 |
11 | import numpy as np
12 | import torch
13 |
14 | import config
15 | # from toolkit.utils.read_files import *
16 | from toolkit.utils.read_files import func_read_key_from_csv
17 |
18 |
19 | def set_seed(seed=0):
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | np.random.seed(seed)
24 | random.seed(seed)
25 | torch.backends.cudnn.benchmark = False
26 | torch.backends.cudnn.deterministic = True
27 |
28 |
29 | def read_data_into_lists(filename, emo_rule):
30 | sample_ids = []
31 | numbers = []
32 | emotions = []
33 | float_values = []
34 |
35 | with open(filename, "r") as file:
36 | for line in file:
37 | parts = line.strip().split()
38 |
39 |
40 | if parts[2] not in config.EMO_RULE[emo_rule]:
41 | continue
42 |
43 | sample_ids.append(parts[0])
44 | numbers.append(int(parts[1]))
45 | emotions.append(parts[2])
46 | float_values.append(float(parts[3]))
47 |
48 | return sample_ids, numbers, emotions, float_values
49 |
50 |
51 | def normalize_dataset_format(
52 | emo_rule,
53 | gt_path,
54 | video_root,
55 | feat_roots,
56 | ):
57 |
58 | all_names, _, all_emos, all_vals = read_data_into_lists(
59 | gt_path, emo_rule
60 | )
61 |
62 | # --------------- split train & test ---------------
63 | split_ratio = 0.2
64 | whole_num = len(all_names)
65 |
66 | # gain indices for cross-validation
67 | indices = np.arange(whole_num)
68 | random.shuffle(indices)
69 |
70 | # split indices into 1-fold
71 | each_folder_num = int(whole_num * split_ratio)
72 | test_idxs = indices[0:each_folder_num]
73 | train_idxs = indices[each_folder_num:]
74 |
75 | split_train_names = []
76 | split_train_emos = []
77 | split_train_vals = []
78 | for idx in train_idxs:
79 | name = all_names[idx]
80 | train_name = osp.join(video_root, "video", name)
81 | split_train_names.append(train_name)
82 | split_train_emos.append(all_emos[idx])
83 | split_train_vals.append(all_vals[idx])
84 |
85 |
86 | feat_dims = []
87 | for feat_root in feat_roots:
88 | feat_dim = np.load(osp.join(feat_root,all_names[0]+".npy")).shape[0]
89 | feat_dims.append(feat_dim)
90 |
91 | split_test_names = []
92 | split_test_emos = []
93 | split_test_vals = []
94 | for idx in test_idxs:
95 | name = all_names[idx]
96 | test_name = osp.join(video_root, name)
97 | split_test_names.append(test_name)
98 | split_test_emos.append(all_emos[idx])
99 | split_test_vals.append(all_vals[idx])
100 |
101 | cnt=1
102 | for dim,root in zip(feat_dims,feat_roots):
103 | print("----------- feat {}".format(cnt))
104 | print("# root={}".format(root))
105 | print("# dim={}".format(dim))
106 | cnt+=1
107 |
108 | train_info = {
109 | "names": split_train_names,
110 | "emos": split_train_emos,
111 | "vals": split_train_vals,
112 | }
113 | valid_info = {
114 | "names": split_test_names,
115 | "emos": split_test_emos,
116 | "vals": split_test_vals,
117 | }
118 |
119 | return {
120 | "feat_dims": feat_dims,
121 | "feat_roots": feat_roots,
122 | "train_info": train_info,
123 | "valid_info": valid_info,
124 | }
125 |
126 |
127 | # run -d toolkit/preprocess/mer2024.py
128 | if __name__ == "__main__":
129 | random_seed = 1
130 | set_seed(random_seed)
131 | print("random seed: {}".format(random_seed))
132 |
133 | # --- data config
134 | emo_rule = "MER"
135 | gt_path = "/mnt/public/share/data/Dataset/MER2024/MER2024-labeled.txt"
136 | video_root = "/mnt/public/share/data/Dataset/MER2024/video/"
137 |
138 | # --- feature config
139 | description = "mer24-train"
140 | feat_roots = ["/mnt/public/share/data/Dataset/MER2024/features/clip-vit-large-patch14-UTT/",
141 | "/mnt/public/share/data/Dataset/MER2024/features/chinese-hubert-large-UTT/",
142 | "/mnt/public/share/data/Dataset/MER2024/features/bloom-7b1-UTT/",]
143 |
144 | # --- save config
145 | save_root = "/mnt/public/gxj_2/EmoNet_Pro/lst_train/"
146 | save_name = "seed{}_MER24".format(random_seed)
147 |
148 | # --- run
149 | dataset_info = normalize_dataset_format(
150 | emo_rule,
151 | gt_path,
152 | video_root,
153 | feat_roots,
154 | )
155 |
156 | feat_dims = dataset_info["feat_dims"]
157 | feat_roots = dataset_info["feat_roots"]
158 | train_info = dataset_info["train_info"]
159 | valid_info = dataset_info["valid_info"]
160 |
161 | train_info["vals"] = [-100] * len(train_info["vals"])
162 | valid_info["vals"] = [-100] * len(valid_info["vals"])
163 |
164 | save_path = os.path.join(save_root, save_name + ".npy")
165 | if not os.path.exists(save_root):
166 | os.makedirs(save_root)
167 |
168 | np.save(
169 | save_path,
170 | {
171 | "feat_dims": feat_dims,
172 | "feat_roots": feat_roots,
173 | "train_info": train_info,
174 | "valid_info": valid_info,
175 | },
176 | )
177 |
178 | print("----------- summary")
179 | print(
180 | "split_train: {}\nsplit_valid: {}".format(
181 | len(train_info["names"]), len(valid_info["names"])
182 | )
183 | )
184 |
--------------------------------------------------------------------------------
/models/auto_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .modules.encoder import LSTMEncoder, MLPEncoder
5 |
6 |
7 | class Auto_Attention(nn.Module):
8 | def __init__(self, args):
9 | super(Auto_Attention, self).__init__()
10 | feat_dims = args.feat_dims
11 | output_dim1 = args.output_dim1
12 | output_dim2 = args.output_dim2
13 | dropout = args.dropout
14 | hidden_dim = args.hidden_dim
15 | self.grad_clip = args.grad_clip
16 | self.feat_dims = feat_dims
17 |
18 | assert args.feat_type == "utt"
19 |
20 | assert len(feat_dims) <= 3 * 6
21 | if len(feat_dims) >= 1:
22 | self.encoder0 = MLPEncoder(feat_dims[0], hidden_dim, dropout)
23 | if len(feat_dims) >= 2:
24 | self.encoder1 = MLPEncoder(feat_dims[1], hidden_dim, dropout)
25 | if len(feat_dims) >= 3:
26 | self.encoder2 = MLPEncoder(feat_dims[2], hidden_dim, dropout)
27 | if len(feat_dims) >= 4:
28 | self.encoder3 = MLPEncoder(feat_dims[3], hidden_dim, dropout)
29 | if len(feat_dims) >= 5:
30 | self.encoder4 = MLPEncoder(feat_dims[4], hidden_dim, dropout)
31 | if len(feat_dims) >= 6:
32 | self.encoder5 = MLPEncoder(feat_dims[5], hidden_dim, dropout)
33 | if len(feat_dims) >= 7:
34 | self.encoder6 = MLPEncoder(feat_dims[6], hidden_dim, dropout)
35 | if len(feat_dims) >= 8:
36 | self.encoder7 = MLPEncoder(feat_dims[7], hidden_dim, dropout)
37 | if len(feat_dims) >= 9:
38 | self.encoder8 = MLPEncoder(feat_dims[8], hidden_dim, dropout)
39 | if len(feat_dims) >= 10:
40 | self.encoder9 = MLPEncoder(feat_dims[9], hidden_dim, dropout)
41 | if len(feat_dims) >= 11:
42 | self.encoder10 = MLPEncoder(feat_dims[10], hidden_dim, dropout)
43 | if len(feat_dims) >= 12:
44 | self.encoder11 = MLPEncoder(feat_dims[11], hidden_dim, dropout)
45 | if len(feat_dims) >= 13:
46 | self.encoder12 = MLPEncoder(feat_dims[12], hidden_dim, dropout)
47 | if len(feat_dims) >= 14:
48 | self.encoder13 = MLPEncoder(feat_dims[13], hidden_dim, dropout)
49 | if len(feat_dims) >= 15:
50 | self.encoder14 = MLPEncoder(feat_dims[14], hidden_dim, dropout)
51 | if len(feat_dims) >= 16:
52 | self.encoder15 = MLPEncoder(feat_dims[15], hidden_dim, dropout)
53 | if len(feat_dims) >= 17:
54 | self.encoder16 = MLPEncoder(feat_dims[16], hidden_dim, dropout)
55 | if len(feat_dims) >= 18:
56 | self.encoder17 = MLPEncoder(feat_dims[17], hidden_dim, dropout)
57 |
58 | self.attention_mlp = MLPEncoder(
59 | hidden_dim * len(feat_dims), hidden_dim, dropout
60 | )
61 |
62 | self.fc_att = nn.Linear(hidden_dim, 3)
63 | self.fc_out_1 = nn.Linear(hidden_dim, output_dim1)
64 | self.fc_out_2 = nn.Linear(hidden_dim, output_dim2)
65 |
66 | def forward(self, batch):
67 | """
68 |         support feat_type: utt (frm-align / frm-unalign are not supported here; see the assert in __init__)
69 | """
70 | hiddens = []
71 | if len(self.feat_dims) >= 1:
72 | hiddens.append(self.encoder0(batch[f"feat0"]))
73 | if len(self.feat_dims) >= 2:
74 | hiddens.append(self.encoder1(batch[f"feat1"]))
75 | if len(self.feat_dims) >= 3:
76 | hiddens.append(self.encoder2(batch[f"feat2"]))
77 | if len(self.feat_dims) >= 4:
78 | hiddens.append(self.encoder3(batch[f"feat3"]))
79 | if len(self.feat_dims) >= 5:
80 | hiddens.append(self.encoder4(batch[f"feat4"]))
81 | if len(self.feat_dims) >= 6:
82 | hiddens.append(self.encoder5(batch[f"feat5"]))
83 | if len(self.feat_dims) >= 7:
84 | hiddens.append(self.encoder6(batch[f"feat6"]))
85 | if len(self.feat_dims) >= 8:
86 | hiddens.append(self.encoder7(batch[f"feat7"]))
87 | if len(self.feat_dims) >= 9:
88 | hiddens.append(self.encoder8(batch[f"feat8"]))
89 | if len(self.feat_dims) >= 10:
90 | hiddens.append(self.encoder9(batch[f"feat9"]))
91 | if len(self.feat_dims) >= 11:
92 | hiddens.append(self.encoder10(batch[f"feat10"]))
93 | if len(self.feat_dims) >= 12:
94 | hiddens.append(self.encoder11(batch[f"feat11"]))
95 | if len(self.feat_dims) >= 13:
96 | hiddens.append(self.encoder12(batch[f"feat12"]))
97 | if len(self.feat_dims) >= 14:
98 | hiddens.append(self.encoder13(batch[f"feat13"]))
99 | if len(self.feat_dims) >= 15:
100 | hiddens.append(self.encoder14(batch[f"feat14"]))
101 | if len(self.feat_dims) >= 16:
102 | hiddens.append(self.encoder15(batch[f"feat15"]))
103 | if len(self.feat_dims) >= 17:
104 | hiddens.append(self.encoder16(batch[f"feat16"]))
105 | if len(self.feat_dims) >= 18:
106 | hiddens.append(self.encoder17(batch[f"feat17"]))
107 |
108 | multi_hidden1 = torch.cat(hiddens, dim=1) # [32, 384]
109 | attention = self.attention_mlp(multi_hidden1)
110 | attention = self.fc_att(attention)
111 | attention = torch.unsqueeze(attention, 2) # [32, 3, 1]
112 |
113 | multi_hidden2 = torch.stack(hiddens, dim=2) # [32, 128, 3]
114 | fused_feat = torch.matmul(
115 | multi_hidden2, attention
116 | ) # [32, 128, 3] * [32, 3, 1] = [32, 128, 1]
117 |
118 |         features = fused_feat.squeeze(dim=2)  # [32, 128] => squeeze only dim 2 so a batch size of 1 is not squeezed away
119 | emos_out = self.fc_out_1(features)
120 | vals_out = self.fc_out_2(features)
121 | interloss = torch.tensor(0).cuda()
122 |
123 | return features, emos_out, vals_out, interloss
124 |
--------------------------------------------------------------------------------
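A shape-check sketch for `Auto_Attention` with three utterance-level experts; the hidden size and dropout follow `models/model-tune.yaml`, while the feature dimensions, batch size, and output dimensions are assumptions. Note that `fc_att` is hard-coded to produce 3 attention weights, so exactly three feature branches are expected, and `forward()` calls `.cuda()`, so a GPU is required.

```python
# Sketch: one forward pass through Auto_Attention via the get_models wrapper.
import argparse

import torch

from models import get_models

args = argparse.Namespace(
    model="auto_attention",
    feat_type="utt",
    feat_dims=[4096, 1024, 4096],  # e.g. video / audio / text utterance features
    hidden_dim=256,
    dropout=0.2,
    grad_clip=-1.0,
    output_dim1=6,  # six MER emotion classes
    output_dim2=1,  # valence regression
)
model = get_models(args).cuda()

batch = {f"feat{i}": torch.randn(8, dim).cuda() for i, dim in enumerate(args.feat_dims)}
features, emos_out, vals_out, interloss = model(batch)
print(features.shape, emos_out.shape, vals_out.shape)  # [8, 256] [8, 6] [8, 1]
```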
/feature_extraction/text/split_asr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import glob
4 | import numpy as np
5 | import pandas as pd
6 | import sys
7 | import argparse
8 |
9 | current_file_path = os.path.abspath(__file__)
10 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path)))
11 | import config
12 |
13 |
14 | # Function 3: read the values of a given key from a csv file
15 | def func_read_key_from_csv(csv_path, key):
16 | values = []
17 | df = pd.read_csv(csv_path)
18 | for _, row in df.iterrows():
19 | if key not in row:
20 | values.append("")
21 | else:
22 | value = row[key]
23 | if pd.isna(value):
24 | value = ""
25 | values.append(value)
26 | return values
27 |
28 |
29 | # names[ii] -> keys=name2key[names[ii]], containing keynames
30 | def func_write_key_to_csv(csv_path, names, name2key, keynames):
31 | ## specific case: only save names
32 | if len(name2key) == 0 or len(keynames) == 0:
33 | df = pd.DataFrame(data=names, columns=["name"])
34 | df.to_csv(csv_path, index=False)
35 | return
36 |
37 | ## other cases:
38 | if isinstance(keynames, str):
39 | keynames = [keynames]
40 | assert isinstance(keynames, list)
41 | columns = ["name"] + keynames
42 |
43 | values = []
44 | for name in names:
45 | value = name2key[name]
46 | values.append(value)
47 | values = np.array(values)
48 | # ensure keynames is mapped
49 | if len(values.shape) == 1:
50 | assert len(keynames) == 1
51 | else:
52 | assert values.shape[-1] == len(keynames)
53 | data = np.column_stack([names, values])
54 |
55 | df = pd.DataFrame(data=data, columns=columns)
56 | df.to_csv(csv_path, index=False)
57 |
58 |
59 | # python main-asr.py generate_transcription_files_asr ./dataset-process/audio ./dataset-process/transcription.csv
60 | def generate_transcription_files_asr(audio_root, save_path):
61 | import torch
62 |
63 | # import wenetruntime as wenet # must load torch first
64 | import wenet
65 |
66 | # from paddlespeech.cli.text.infer import TextExecutor
67 | # text_punc = TextExecutor()
68 | # decoder = wenet.Decoder(config.PATH_TO_WENET, lang='chs')
69 | decoder = wenet.load_model(language="chinese")
70 |
71 | names = []
72 | sentences = []
73 | for audio_path in tqdm.tqdm(glob.glob(audio_root + "/*")):
74 | name = os.path.basename(audio_path)[:-4]
75 | # sentence = decoder.decode_wav(audio_path)
76 | # sentence = sentence.split('"')[5]
77 | sentence = decoder.transcribe(audio_path)
78 | sentence = sentence["text"]
79 | # if len(sentence) > 0: sentence = text_punc(text=sentence)
80 | names.append(name)
81 | sentences.append(sentence)
82 |
83 | ## write to csv file
84 | columns = ["name", "sentence"]
85 | data = np.column_stack([names, sentences])
86 | df = pd.DataFrame(data=data, columns=columns)
87 | df[columns] = df[columns].astype(str)
88 | df.to_csv(save_path, index=False)
89 |
90 |
91 | # python main-asr.py refinement_transcription_files_asr(old_path, new_path)
92 | def refinement_transcription_files_asr(old_path, new_path):
93 | from paddlespeech.cli.text.infer import TextExecutor
94 |
95 | text_punc = TextExecutor()
96 |
97 | ## read
98 | names, sentences = [], []
99 | df_label = pd.read_csv(old_path)
100 | for _, row in df_label.iterrows(): ## read for each row
101 | names.append(row["name"])
102 | sentence = row["sentence"]
103 | if pd.isna(sentence):
104 | sentences.append("")
105 | else:
106 | sentence = text_punc(text=sentence)
107 | sentences.append(sentence)
108 | print(sentences[-1])
109 |
110 | ## write
111 | columns = ["name", "chinese"]
112 | data = np.column_stack([names, sentences])
113 | df = pd.DataFrame(data=data, columns=columns)
114 | df[columns] = df[columns].astype(str)
115 | df.to_csv(new_path, index=False)
116 |
117 |
118 | # python main-asr.py merge_trans_with_checked dataset/mer2024-dataset-process/transcription.csv dataset/mer2024-dataset/label-transcription.csv dataset/mer2024-dataset-process/transcription-merge.csv
119 | def merge_trans_with_checked(new_path, check_path, merge_path):
120 |
121 | # read new_path 7369
122 | name2new = {}
123 | names = func_read_key_from_csv(new_path, "name")
124 | trans = func_read_key_from_csv(new_path, "sentence")
125 | for name, tran in zip(names, trans):
126 | name2new[name] = tran
127 | print(f"new sample: {len(name2new)}")
128 |
129 | # read check_path 5030
130 | name2check = {}
131 | names = func_read_key_from_csv(check_path, "name")
132 | trans = func_read_key_from_csv(check_path, "chinese")
133 | for name, tran in zip(names, trans):
134 | name2check[name] = tran
135 | print(f"check sample: {len(name2check)}")
136 |
137 |     # generate the new merged results
138 | name2merge = {}
139 | for name in name2new:
140 | if name in name2check:
141 | name2merge[name] = [name2check[name]]
142 | else:
143 | name2merge[name] = [name2new[name]]
144 | print(f"merge sample: {len(name2merge)}")
145 |
146 |     # save name2merge
147 | names = [name for name in name2merge]
148 | keynames = ["chinese"]
149 | func_write_key_to_csv(merge_path, names, name2merge, keynames)
150 |
151 |
152 | if __name__ == "__main__":
153 | parser = argparse.ArgumentParser()
154 | parser.add_argument(
155 | "--dataset",
156 | type=str,
157 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/",
158 | help="file name",
159 | )
160 | args = parser.parse_args()
161 |
162 | dataset = args.dataset
163 |
164 | audio_root = os.path.join(dataset, "audio")
165 | save_root = os.path.join(dataset, "text")
166 |
167 | if not os.path.exists(save_root):
168 | os.makedirs(save_root)
169 |
170 | generate_transcription_files_asr(
171 | audio_root,
172 | os.path.join(save_root, "transcription-old.csv"),
173 | )
174 |
175 | refinement_transcription_files_asr(
176 | os.path.join(save_root, "transcription-old.csv"),
177 | os.path.join(save_root, "transcription.csv"),
178 | )
179 |
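180 | # A minimal usage sketch of func_write_key_to_csv (the names, values, and output path
181 | # below are hypothetical; kept as comments so the module stays side-effect free):
182 | #   name2key = {"sample_0001": ["happy", 0.8], "sample_0002": ["sad", -0.3]}
183 | #   func_write_key_to_csv("labels.csv", list(name2key), name2key, ["emo", "val"])
184 | #   -> writes a csv with columns [name, emo, val], one row per name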
--------------------------------------------------------------------------------
/toolkit/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | import numpy as np
5 |
6 |
7 | # classification loss
8 | class CELoss(nn.Module):
9 |
10 | def __init__(self):
11 | super(CELoss, self).__init__()
12 | self.loss = nn.NLLLoss(reduction="sum")
13 |
14 | def forward(self, pred, target):
15 | pred = F.log_softmax(pred, 1) # [n_samples, n_classes]
16 | target = target.long() # [n_samples]
17 | loss = self.loss(pred, target) / len(pred)
18 | return loss
19 |
20 |
21 | # regression loss
22 | class MSELoss(nn.Module):
23 |
24 | def __init__(self):
25 | super(MSELoss, self).__init__()
26 | self.loss = nn.MSELoss(reduction="sum")
27 |
28 | def forward(self, pred, target):
29 | pred = pred.view(-1, 1)
30 | target = target.view(-1, 1)
31 | loss = self.loss(pred, target) / len(pred)
32 | return loss
33 |
34 |
35 | class CenterLoss(nn.Module):
36 | def __init__(self, num_classes, feat_dim, lambda_c=0.5):
37 | super(CenterLoss, self).__init__()
38 |         self.num_classes = num_classes  # number of classes
39 |         self.feat_dim = feat_dim  # feature dimension
40 |         self.lambda_c = lambda_c  # balancing coefficient
41 |         # initialize a feature center for each class
42 | # self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
43 | self.centers = nn.Parameter(torch.FloatTensor(num_classes, feat_dim))
44 | nn.init.xavier_uniform_(self.centers)
45 |
46 | def forward(self, x, labels):
47 | """
48 |         x: feature vectors of the current batch (batch_size, feat_dim)
49 |         labels: class labels of the current batch (batch_size)
50 | """
51 | batch_size = x.size(0)
52 |
53 |         # gather the centers for the classes of the current batch
54 | centers_batch = self.centers.cuda().index_select(0, labels)
55 |
56 |         # compute the center loss
57 | center_loss = (
58 | self.lambda_c * 0.5 * torch.sum((x - centers_batch) ** 2) / batch_size
59 | )
60 | return center_loss
61 |
62 |
63 | class MultiClassFocalLossWithAlpha(nn.Module):
64 | def __init__(self, alpha=[0.2, 0.3, 0.5], gamma=2, reduction="mean", classnum=None):
65 | """
66 |         :param alpha: list of per-class weights; for three classes, class 0 gets 0.2, class 1 gets 0.3, class 2 gets 0.5
67 |         :param gamma: focusing parameter for hard-example mining
68 |         :param reduction: 'mean', 'sum', or anything else to return per-sample losses
69 | """
70 | super(MultiClassFocalLossWithAlpha, self).__init__()
71 | if classnum is None:
72 | self.alpha = torch.tensor(alpha)
73 | else:
74 | self.alpha = torch.tensor([1.0 / classnum] * classnum)
75 | self.gamma = gamma
76 | self.reduction = reduction
77 |
78 | def forward(self, pred, target):
79 |         alpha = self.alpha.cuda()[
80 |             target
81 |         ]  # assign each sample in the batch its class weight, shape=(bs), 1-D vector
82 |         log_softmax = torch.log_softmax(
83 |             pred, dim=1
84 |         )  # log of the softmax over the raw model outputs, shape=(bs, 3)
85 |         logpt = torch.gather(
86 |             log_softmax, dim=1, index=target.view(-1, 1)
87 |         )  # pick each sample's log_softmax value at its label position, shape=(bs, 1)
88 |         logpt = logpt.view(-1)  # flatten, shape=(bs)
89 |         ce_loss = -logpt  # negating the log_softmax gives the cross-entropy term
90 |         pt = torch.exp(
91 |             logpt
92 |         )  # exp cancels the log: the softmax probability at the label position, shape=(bs)
93 |         focal_loss = (
94 |             alpha * (1 - pt) ** self.gamma * ce_loss
95 |         )  # per-sample focal loss computed from the formula, shape=(bs)
96 | if self.reduction == "mean":
97 | return torch.mean(focal_loss)
98 | if self.reduction == "sum":
99 | return torch.sum(focal_loss)
100 | return focal_loss
101 |
102 |
103 | class LDAMLoss(nn.Module):
104 | def __init__(self, cls_num_list, device, max_m=0.5, weight=None, s=30):
105 | super(LDAMLoss, self).__init__()
106 | print("LDAM weights:", weight)
107 | m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
108 | m_list = m_list * (max_m / np.max(m_list))
109 | m_list = torch.cuda.FloatTensor(m_list)
110 | self.m_list = m_list
111 | self.device = device
112 | assert s > 0
113 | self.s = s
114 | self.weight = weight
115 |
116 | def forward(self, x, target):
117 | index = torch.zeros_like(x, dtype=torch.uint8)
118 | index.scatter_(1, target.data.view(-1, 1), 1)
119 |
120 | index_float = index.type(torch.cuda.FloatTensor)
121 | batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))
122 | batch_m = batch_m.view((-1, 1))
123 | x_m = x - batch_m
124 |
125 | output = torch.where(index, x_m, x)
126 | return F.cross_entropy(self.s * output, target, weight=self.weight)
127 |         # criterion = LDAMLoss(cls_num_list=a list of the number of samples in each class, max_m=0.5, s=30, weight=per_cls_weights)
128 | """
129 | max_m: represents the margin used in the loss function. It controls the separation between different classes in the feature space.
130 | The appropriate value for max_m depends on the specific dataset and the severity of the class imbalance.
131 | You can start with a small value and gradually increase it to observe the impact on the model's performance.
132 | If the model struggles with class separation or experiences underfitting, increasing max_m might help. However,
133 | be cautious not to set it too high, as it can cause overfitting or make the model too conservative.
134 |
135 | s: higher s value results in sharper probabilities (more confident predictions), while a lower s value
136 | leads to more balanced probabilities. In practice, a larger s value can help stabilize training and improve convergence,
137 | especially when dealing with difficult optimization landscapes.
138 | The choice of s depends on the desired scale of the logits and the specific requirements of your problem.
139 | It can be used to adjust the balance between the margin and the original logits. A larger s value amplifies
140 | the impact of the logits and can be useful when dealing with highly imbalanced datasets.
141 | You can experiment with different values of s to find the one that works best for your dataset and model.
142 |
143 | """
144 |
145 |
146 | class LMFLoss(nn.Module):
147 | def __init__(
148 | self,
149 | cls_num_list,
150 | device,
151 | weight=None,
152 | alpha=0.2,
153 | beta=0.2,
154 | gamma=2,
155 | max_m=0.8,
156 | s=5,
157 | add_LDAM_weigth=False,
158 | ):
159 | super().__init__()
160 | self.focal_loss = MultiClassFocalLossWithAlpha(classnum=len(cls_num_list))
161 | if add_LDAM_weigth:
162 | LDAM_weight = weight
163 | else:
164 | LDAM_weight = None
165 | print(
166 | "LMF loss: alpha: ",
167 | alpha,
168 | " beta: ",
169 | beta,
170 | " gamma: ",
171 | gamma,
172 | " max_m: ",
173 | max_m,
174 | " s: ",
175 | s,
176 | " LDAM_weight: ",
177 | add_LDAM_weigth,
178 | )
179 | self.ldam_loss = LDAMLoss(cls_num_list, device, max_m, weight=LDAM_weight, s=s)
180 | self.alpha = alpha
181 | self.beta = beta
182 |
183 | def forward(self, output, target):
184 | focal_loss_output = self.focal_loss(output, target)
185 | ldam_loss_output = self.ldam_loss(output, target)
186 | total_loss = self.alpha * focal_loss_output + self.beta * ldam_loss_output
187 | return total_loss
188 |
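189 | # A minimal usage sketch of LMFLoss (the class counts below are hypothetical; inputs
190 | # must live on cuda because LDAMLoss builds its margins with torch.cuda.FloatTensor):
191 | #   cls_num_list = [1200, 300, 80]        # imbalanced samples per class
192 | #   criterion = LMFLoss(cls_num_list, device="cuda", alpha=0.2, beta=0.2, max_m=0.8, s=5)
193 | #   loss = criterion(logits, targets)     # logits: (bs, 3) on cuda, targets: (bs,) long on cuda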
--------------------------------------------------------------------------------
/feature_extraction/audio/extract_audio_huggingface.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import math
4 | import os
5 |
6 | # import config
7 | import sys
8 | import time
9 |
10 | import numpy as np
11 | import soundfile as sf
12 | import torch
13 |
14 | current_file_path = os.path.abspath(__file__)
15 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path)))
16 | from transformers import AutoModel, Wav2Vec2FeatureExtractor, WhisperFeatureExtractor
17 |
18 | import config
19 |
20 | # supported models
21 | ################## ENGLISH ######################
22 | WAV2VEC2_BASE = (
23 | "wav2vec2-base-960h" # https://huggingface.co/facebook/wav2vec2-base-960h
24 | )
25 | WAV2VEC2_LARGE = (
26 | "wav2vec2-large-960h" # https://huggingface.co/facebook/wav2vec2-large-960h
27 | )
28 | DATA2VEC_AUDIO_BASE = "data2vec-audio-base-960h" # https://huggingface.co/facebook/data2vec-audio-base-960h
29 | DATA2VEC_AUDIO_LARGE = (
30 | "data2vec-audio-large" # https://huggingface.co/facebook/data2vec-audio-large
31 | )
32 |
33 | ################## CHINESE ######################
34 | HUBERT_BASE_CHINESE = (
35 | "chinese-hubert-base" # https://huggingface.co/TencentGameMate/chinese-hubert-base
36 | )
37 | HUBERT_LARGE_CHINESE = "chinese-hubert-large" # https://huggingface.co/TencentGameMate/chinese-hubert-large
38 | WAV2VEC2_BASE_CHINESE = "chinese-wav2vec2-base" # https://huggingface.co/TencentGameMate/chinese-wav2vec2-base
39 | WAV2VEC2_LARGE_CHINESE = "chinese-wav2vec2-large" # https://huggingface.co/TencentGameMate/chinese-wav2vec2-large
40 |
41 | ################## Multilingual #################
42 | WAVLM_BASE = "wavlm-base" # https://huggingface.co/microsoft/wavlm-base
43 | WAVLM_LARGE = "wavlm-large" # https://huggingface.co/microsoft/wavlm-large
44 | WHISPER_BASE = "whisper-base" # https://huggingface.co/openai/whisper-base
45 | WHISPER_LARGE = "whisper-large-v2" # https://huggingface.co/openai/whisper-large-v2
46 |
47 |
48 | ## Goal: avoid overly long inputs
49 | # input_values: [1, wavlen], output: [bsize, maxlen]
50 | def split_into_batch(input_values, maxlen=16000 * 10):
51 | if len(input_values[0]) <= maxlen:
52 | return input_values
53 |
54 | bs, wavlen = input_values.shape
55 | assert bs == 1
56 | tgtlen = math.ceil(wavlen / maxlen) * maxlen
57 | batches = torch.zeros((1, tgtlen))
58 | batches[:, :wavlen] = input_values
59 | batches = batches.view(-1, maxlen)
60 | return batches
61 |
62 |
63 | def extract(model_name, audio_files, save_dir, feature_level, gpu):
64 |
65 | start_time = time.time()
66 |
67 | # load model
68 | model_file = os.path.join(
69 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{model_name}"
70 | )
71 |
72 | if model_name in [WHISPER_BASE, WHISPER_LARGE]:
73 | model = AutoModel.from_pretrained(model_file)
74 | feature_extractor = WhisperFeatureExtractor.from_pretrained(model_file)
75 | else:
76 | model = AutoModel.from_pretrained(model_file)
77 | feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_file)
78 |
79 | load_time = time.time()
80 |
81 | if gpu != -1:
82 | device = torch.device(f"cuda:{gpu}")
83 | model.to(device)
84 | model.eval()
85 |
86 | # iterate audios
87 | t1 = time.time()
88 | for idx, audio_file in enumerate(audio_files, 1):
89 | file_name = os.path.basename(audio_file)
90 | vid = file_name[:-4]
91 | print(f'Processing "{file_name}" ({idx}/{len(audio_files)})...')
92 |
93 | ## process for too short ones
94 | samples, sr = sf.read(audio_file)
95 | assert sr == 16000, "currently, we only test on 16k audio"
96 |
97 | ## model inference
98 | with torch.no_grad():
99 | if model_name in [WHISPER_BASE, WHISPER_LARGE]:
100 | layer_ids = [-1]
101 | input_features = feature_extractor(
102 | samples, sampling_rate=sr, return_tensors="pt"
103 | ).input_features # [1, 80, 3000]
104 | decoder_input_ids = (
105 | torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
106 | )
107 | if gpu != -1:
108 | input_features = input_features.to(device)
109 | if gpu != -1:
110 | decoder_input_ids = decoder_input_ids.to(device)
111 | last_hidden_state = model(
112 | input_features, decoder_input_ids=decoder_input_ids
113 | ).last_hidden_state
114 | assert last_hidden_state.shape[0] == 1
115 | feature = (
116 | last_hidden_state[0].detach().squeeze().cpu().numpy()
117 | ) # (2, D)
118 | else:
119 | layer_ids = [-4, -3, -2, -1]
120 | input_values = feature_extractor(
121 | samples, sampling_rate=sr, return_tensors="pt"
122 | ).input_values # [1, wavlen]
123 | input_values = split_into_batch(
124 | input_values
125 | ) # [bsize, maxlen=10*16000]
126 | if gpu != -1:
127 | input_values = input_values.to(device)
128 | hidden_states = model(
129 | input_values, output_hidden_states=True
130 | ).hidden_states # tuple of (B, T, D)
131 | feature = torch.stack(hidden_states)[layer_ids].sum(
132 | dim=0
133 | ) # (B, T, D) # -> compress waveform channel
134 | bsize, segnum, featdim = feature.shape
135 | feature = (
136 | feature.view(-1, featdim).detach().squeeze().cpu().numpy()
137 | ) # (B*T, D)
138 |
139 | ## store values
140 | csv_file = os.path.join(save_dir, f"{vid}.npy")
141 | if feature_level == "UTTERANCE":
142 | feature = np.array(feature).squeeze()
143 | if len(feature.shape) != 1:
144 | feature = np.mean(feature, axis=0)
145 | np.save(csv_file, feature)
146 | else:
147 | np.save(csv_file, feature)
148 |
149 | t2 = time.time()
150 | end_time = time.time()
151 | print(f"Model load time used: {load_time - start_time:.1f}s.")
152 | print(f"Total time used: {end_time - start_time:.1f}s.")
153 |     print(f"Average time per file: {(t2 - t1)/len(audio_files):.1f}s.")
154 |
155 |
156 | if __name__ == "__main__":
157 |
158 | parser = argparse.ArgumentParser(description="Run.")
159 | parser.add_argument("--gpu", type=int, default=0, help="index of gpu")
160 | parser.add_argument(
161 | "--model_name",
162 | type=str,
163 | default="chinese-hubert-large",
164 | help="feature extractor",
165 | )
166 | parser.add_argument(
167 | "--feature_level", type=str, default="FRAME", help="FRAME or UTTERANCE"
168 | )
169 | parser.add_argument(
170 | "--dataset",
171 | type=str,
172 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/",
173 | help="input dataset",
174 | )
175 |     # ------ temporary test: impact of SNR on the results ------
176 | parser.add_argument(
177 | "--noise_case",
178 | type=str,
179 | default=None,
180 | help="extract feature of different noise conditions",
181 | )
182 |     # ------ temporary test: impact of tts audio on the results -------
183 | parser.add_argument(
184 | "--tts_lang",
185 | type=str,
186 | default=None,
187 | help="extract feature from tts audio, [chinese, english]",
188 | )
189 | args = parser.parse_args()
190 |
191 | # analyze input
192 | audio_dir = os.path.join(args.dataset, "audio")
193 | save_dir = os.path.join(args.dataset, "features")
194 |
195 | if args.noise_case is not None:
196 | audio_dir += "_" + args.noise_case
197 | if args.tts_lang is not None:
198 | audio_dir += "-" + f"tts{args.tts_lang[:3]}16k"
199 |
200 | # audio_files
201 | audio_files = glob.glob(os.path.join(audio_dir, "*.wav"))
202 | print(f'Find total "{len(audio_files)}" audio files.')
203 |
204 | # save_dir
205 | if args.noise_case is not None:
206 | dir_name = f"{args.model_name}-noise{args.noise_case}-{args.feature_level[:3]}"
207 | elif args.tts_lang is not None:
208 | dir_name = f"{args.model_name}-tts{args.tts_lang[:3]}-{args.feature_level[:3]}"
209 | else:
210 | dir_name = f"{args.model_name}-{args.feature_level[:3]}"
211 |
212 | save_dir = os.path.join(save_dir, dir_name)
213 | if not os.path.exists(save_dir):
214 | os.makedirs(save_dir)
215 |
216 | # extract features
217 | extract(args.model_name, audio_files, save_dir, args.feature_level, gpu=args.gpu)
218 |
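219 | # Example invocation (the dataset path below is a placeholder; adjust it to your layout):
220 | #   python extract_audio_huggingface.py --gpu 0 --model_name chinese-hubert-large \
221 | #       --feature_level UTTERANCE --dataset /path/to/mer-dataset-process/
222 | # Utterance-level features are written to <dataset>/features/chinese-hubert-large-UTT/<vid>.npy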
--------------------------------------------------------------------------------
/feature_extraction/visual/extract_openface.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import shutil
5 | import pathlib
6 | import argparse
7 | import numpy as np
8 | from util import read_hog, read_csv
9 |
10 | import sys
11 |
12 | current_file_path = os.path.abspath(__file__)
13 | sys.path.append(os.path.dirname(os.path.dirname((current_file_path))))
14 |
15 | import config
16 |
17 |
18 | def generate_face_faceDir(input_root, save_root, savetype="image"):
19 | if savetype == "image":
20 | for dir_path in sorted(
21 | glob.glob(input_root + "/*_aligned")
22 | ): # 'xx/xx/000100_guest_aligned'
23 | frame_names = os.listdir(dir_path) # ['xxx.bmp']
24 | if len(frame_names) != 1:
25 | continue
26 | frame_path = os.path.join(
27 | dir_path, frame_names[0]
28 | ) # 'xx/xx/000100_guest_aligned/xxx.bmp'
29 | name = os.path.basename(dir_path)[: -len("_aligned")] # '000100_guest'
30 | save_path = os.path.join(save_root, name + ".bmp")
31 | shutil.copy(frame_path, save_path)
32 | elif savetype == "npy":
33 | frames = []
34 | for dir_path in sorted(
35 | glob.glob(input_root + "/*_aligned")
36 | ): # 'xx/xx/000100_guest_aligned'
37 | frame_names = os.listdir(dir_path) # ['xxx.bmp']
38 | if len(frame_names) != 1:
39 | continue
40 | frame_path = os.path.join(
41 | dir_path, frame_names[0]
42 | ) # 'xx/xx/000100_guest_aligned/xxx.bmp'
43 | frame = cv2.imread(frame_path)
44 | frames.append(frame)
45 | videoname = os.path.basename(input_root)
46 | save_path = os.path.join(save_root, videoname + ".npy")
47 | np.save(save_path, frames)
48 |
49 |
50 | def generate_face_videoOne(input_root, save_root, savetype="image"):
51 | for dir_path in glob.glob(
52 | input_root + "/*_aligned"
53 | ): # 'xx/xx/000100_guest_aligned'
54 | frame_names = sorted(os.listdir(dir_path)) # ['xxx.bmp']
55 | if savetype == "image":
56 | for ii in range(len(frame_names)):
57 | frame_path = os.path.join(
58 | dir_path, frame_names[ii]
59 | ) # 'xx/xx/000100_guest_aligned/xxx.bmp'
60 | frame_name = os.path.basename(frame_path)
61 | save_path = os.path.join(save_root, frame_name)
62 | shutil.copy(frame_path, save_path)
63 | elif savetype == "npy":
64 | frames = []
65 | for ii in range(len(frame_names)):
66 | frame_path = os.path.join(dir_path, frame_names[ii])
67 | frame = cv2.imread(frame_path)
68 | frames.append(frame)
69 | videoname = os.path.basename(input_root)
70 | save_path = os.path.join(save_root, videoname + ".npy")
71 | np.save(save_path, frames)
72 |
73 |
74 | def generate_hog(input_root, save_root):
75 | for hog_path in glob.glob(input_root + "/*.hog"):
76 | csv_path = hog_path[:-4] + ".csv"
77 | if os.path.exists(csv_path):
78 | hog_name = os.path.basename(hog_path)[:-4]
79 | _, feature = read_hog(hog_path)
80 | save_path = os.path.join(save_root, hog_name + ".npy")
81 | np.save(save_path, feature)
82 |
83 |
84 | def generate_csv(input_root, save_root, startIdx):
85 | for csv_path in glob.glob(input_root + "/*.csv"):
86 | csv_name = os.path.basename(csv_path)[:-4]
87 | feature = read_csv(csv_path, startIdx)
88 | save_path = os.path.join(save_root, csv_name + ".npy")
89 | np.save(save_path, feature)
90 |
91 |
92 | # name_npy: only process on names in 'name_npy'
93 | def extract(
94 | input_dir, process_type, save_dir, face_dir, hog_dir, pose_dir, name_npy=None
95 | ):
96 |
97 | # => process_names
98 |     if name_npy is not None:  # process only the files listed in 'name_npy'
99 |         process_names = np.load(name_npy)
100 |     else:  # process all video files
101 | vids = os.listdir(input_dir)
102 | process_names = [vid.rsplit(".", 1)[0] for vid in vids]
103 | print(f"processing names: {len(process_names)}")
104 |
105 | # process folders
106 | vids = os.listdir(input_dir)
107 | print(f'Find total "{len(vids)}" videos.')
108 | for i, vid in enumerate(vids, 1):
109 | saveVid = vid.rsplit(".", 1)[0] # unify folder and video names
110 | if saveVid not in process_names:
111 | continue
112 |
113 | print(f"Processing video '{vid}' ({i}/{len(vids)})...")
114 | input_root = os.path.join(input_dir, vid) # exists
115 | save_root = os.path.join(save_dir, saveVid)
116 | face_root = os.path.join(face_dir, saveVid)
117 | hog_root = os.path.join(hog_dir, saveVid)
118 | pose_root = os.path.join(pose_dir, saveVid)
119 | if os.path.exists(face_root):
120 | continue
121 | if not os.path.exists(save_root):
122 | os.makedirs(save_root)
123 | if not os.path.exists(face_root):
124 | os.makedirs(face_root)
125 | if not os.path.exists(hog_root):
126 | os.makedirs(hog_root)
127 | if not os.path.exists(pose_root):
128 | os.makedirs(pose_root)
129 | if process_type == "faceDir":
130 | exe_path = os.path.join(config.PATH_TO_OPENFACE_Win, "FaceLandmarkImg.exe")
131 |             command = '%s -fdir "%s" -out_dir "%s"' % (exe_path, input_root, save_root)
132 |             os.system(command)
133 | # generate_face_faceDir(save_root, face_root, savetype='image') # more subtle files
134 | generate_face_faceDir(
135 | save_root, face_root, savetype="npy"
136 | ) # compress frames into npy
137 | # generate_hog(save_root, hog_root) # not used
138 | # generate_csv(save_root, pose_root, startIdx=2) # not used
139 | ## delete temp folder
140 | dir_path = pathlib.Path(save_root)
141 | shutil.rmtree(dir_path)
142 | elif process_type == "videoOne":
143 | exe_path = os.path.join(
144 | config.PATH_TO_OPENFACE_Win, "build/bin/FeatureExtraction"
145 | ) # jxie
146 |             command = '%s -f "%s" -out_dir "%s"' % (exe_path, input_root, save_root)
147 |             os.system(command)
148 | # generate_face_videoOne(save_root, face_root, savetype='image') # more subtle files
149 | generate_face_videoOne(
150 | save_root, face_root, savetype="npy"
151 | ) # compress frames into npy
152 | # generate_hog(save_root, hog_root)
153 | # generate_csv(save_root, pose_root, startIdx=5)
154 | ## delete temp folder
155 | dir_path = pathlib.Path(save_root)
156 | shutil.rmtree(dir_path)
157 |
158 |
159 | if __name__ == "__main__":
160 | parser = argparse.ArgumentParser(description="Run.")
161 | parser.add_argument(
162 | "--overwrite",
163 | action="store_true",
164 | default=True,
165 |         help="whether to overwrite the existing feature folder.",
166 | )
167 | parser.add_argument(
168 | "--dataset",
169 | type=str,
170 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/",
171 | help="input dataset dir path",
172 | )
173 | parser.add_argument("--name_npy", type=str, default=None, help="process name lists")
174 | parser.add_argument(
175 | "--type",
176 | type=str,
177 | default="videoOne",
178 | choices=["faceDir", "videoOne"],
179 | help="faceDir: process on facedirs; videoOne: process on one video",
180 | )
181 | params = parser.parse_args()
182 |
183 | print(f"==> Extracting openface features...")
184 |
185 | # in: face dir
186 | dataset = params.dataset
187 | process_type = params.type
188 | input_dir = os.path.join(dataset, "video")
189 |
190 | # out: feature csv dir
191 | save_dir = os.path.join(dataset, "features", "openface_all")
192 | hog_dir = os.path.join(dataset, "features", "openface_hog")
193 | pose_dir = os.path.join(dataset, "features", "openface_pose")
194 | face_dir = os.path.join(dataset, "features", "openface_face")
195 |
196 | if not os.path.exists(save_dir):
197 | os.makedirs(save_dir)
198 | if not os.path.exists(hog_dir):
199 | os.makedirs(hog_dir)
200 | if not os.path.exists(pose_dir):
201 | os.makedirs(pose_dir)
202 | if not os.path.exists(face_dir):
203 | os.makedirs(face_dir)
204 |
205 | # process
206 | extract(
207 | input_dir, process_type, save_dir, face_dir, hog_dir, pose_dir, params.name_npy
208 | )
209 |
210 | print(f"==> Finish")
211 |
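212 | # Example invocation (the dataset path below is a placeholder; assumes the OpenFace
213 | # binaries are available under config.PATH_TO_OPENFACE_Win from feature_extraction/config.py):
214 | #   python extract_openface.py --type videoOne --dataset /path/to/mer-dataset-process/
215 | # Aligned faces of each video are compressed into <dataset>/features/openface_face/<vid>/<vid>.npy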
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import logging
4 | import os
5 | import os.path as osp
6 | import random
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import torch
11 | import torch.optim as optim
12 | from omegaconf import OmegaConf
13 | from sklearn.metrics import (
14 | accuracy_score,
15 | confusion_matrix,
16 | f1_score,
17 | precision_score,
18 | recall_score,
19 | )
20 | from tqdm import tqdm
21 |
22 | import config
23 | from dataloader.dataloader import get_dataloader_train, get_dataloader_valid
24 | from models import get_models
25 | from toolkit.utils.draw_process import draw_loss, draw_metric
26 | from toolkit.utils.eval import calculate_results
27 | from toolkit.utils.functions import func_update_storage, merge_args_config
28 | from toolkit.utils.loss import CELoss, MSELoss
29 | from toolkit.utils.metric import gain_metric_from_results
30 |
31 | # emotion rule
32 | # emotions = ["neutral", "angry", "happy", "sad", "worried", "surprise"]
33 |
34 |
35 | def set_seed(seed=0):
36 | torch.manual_seed(seed)
37 | torch.cuda.manual_seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 | np.random.seed(seed)
40 | random.seed(seed)
41 | torch.backends.cudnn.benchmark = False
42 | torch.backends.cudnn.deterministic = True
43 |
44 |
45 | def test_model(
46 | args,
47 | model,
48 | dataloader,
49 | ):
50 |
51 | vidnames = []
52 | val_preds, val_labels = [], []
53 | emo_probs, emo_labels = [], []
54 | losses = []
55 |
56 | model.eval()
57 |
58 | pbar = tqdm(range(len(dataloader)), desc=f"test")
59 |
60 | for iter, data in enumerate(dataloader):
61 |
62 | # read data + cuda
63 | audios, texts, videos, emos, vals, bnames = data
64 |
65 | vidnames += bnames
66 |
67 | batch = {}
68 | batch["videos"] = videos.float().cuda()
69 | batch["audios"] = audios.float().cuda()
70 | batch["texts"] = texts.float().cuda()
71 |
72 | emos = emos.long().cuda()
73 | vals = vals.float().cuda()
74 |
75 | if args.train_input_mode == "input_gt":
76 | _, emos_out, _, _ = model([batch, emos])
77 | elif args.train_input_mode == "input":
78 | _, emos_out, _, _ = model(batch)
79 |
80 | emo_probs.append(emos_out.data.cpu().numpy())
81 | emo_labels.append(emos.data.cpu().numpy())
82 |
83 | pbar.update(1)
84 |
85 | pbar.close()
86 |
87 | if emo_probs != []:
88 | emo_probs = np.concatenate(emo_probs)
89 | if emo_labels != []:
90 | emo_labels = np.concatenate(emo_labels)
91 | if val_preds != []:
92 | val_preds = np.concatenate(val_preds)
93 | if val_labels != []:
94 | val_labels = np.concatenate(val_labels)
95 | results, _ = calculate_results(emo_probs, emo_labels, val_preds, val_labels)
96 | save_results = dict(
97 | names=vidnames,
98 | **results,
99 | )
100 |
101 | y_true = []
102 | y_pred = []
103 | emo_preds = np.argmax(emo_probs, 1)
104 | for emo_pred, emo_label in zip(emo_preds, emo_labels):
105 | y_pred.append(emotions[emo_pred])
106 | y_true.append(emotions[emo_label])
107 | save_results["emotions"] = emotions
108 | conf_matrix = confusion_matrix(y_true, y_pred, labels=emotions)
109 | accuracy = accuracy_score(y_true, y_pred)
110 | precision = precision_score(y_true, y_pred, average="weighted")
111 | recall = recall_score(y_true, y_pred, average="weighted")
112 | f1 = f1_score(y_true, y_pred, average="weighted")
113 | print("emotions: " + str(emotions))
114 | print("Confusion Matrix:")
115 | print(str(conf_matrix))
116 | print(f"Accuracy: {accuracy:.4f}")
117 | print(f"Precision: {precision:.4f}")
118 | print(f"Recall: {recall:.4f}")
119 | print(f"F1 Score: {f1:.4f}")
120 | return save_results
121 |
122 |
123 | if __name__ == "__main__":
124 | parser = argparse.ArgumentParser()
125 | parser.add_argument(
126 | "--seed",
127 | type=int,
128 | default=0,
129 | help="random seed",
130 | )
131 |
132 | # --- data config
133 | parser.add_argument(
134 | "--dataset",
135 | type=str,
136 | default="MER24-test_3A_whisper-base-UTT",
137 | help="dataset info name",
138 | )
139 | parser.add_argument(
140 | "--emo_rule",
141 | type=str,
142 | default="MER",
143 | help="emo map function from emotion to index",
144 | )
145 |
146 | # --- save config
147 | parser.add_argument(
148 | "--save_root",
149 | type=str,
150 | default="/mnt/public/gxj/EmoNets/saved_test",
151 | help="save prediction results and models",
152 | )
153 | parser.add_argument(
154 | "--save_name",
155 | type=str,
156 | default="test",
157 | help="save prediction results and models",
158 | )
159 |
160 | # --- model config
161 | parser.add_argument(
162 | "--model",
163 | type=str,
164 | default="attention",
165 | help="model name for training [attention, mer_rank5, and others]",
166 | )
167 | parser.add_argument(
168 | "--load_key",
169 | type=str,
170 | default="/mnt/public/gxj/EmoNets/saved/seed0_MER24_3A_whisper-base-UTT/attention/2025-04-06-21-26-53",
171 | help="keyword about which model weight to load",
172 | )
173 |
174 | # --- test sets
175 | parser.add_argument(
176 | "--batch_size",
177 | type=int,
178 | default=512,
179 | metavar="BS",
180 | help="batch size",
181 | )
182 | parser.add_argument(
183 | "--num_workers", type=int, default=4, metavar="nw", help="number of workers"
184 | )
185 | parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
186 |
187 | args = parser.parse_args()
188 |
189 | set_seed(args.seed)
190 |
191 | torch.cuda.set_device(args.gpu)
192 |
193 |     # if load_key is not a registered keyword, fall back to treating it as a path
194 | if args.load_key not in config.MODEL_DIR_DICT[args.model].keys():
195 | config.MODEL_DIR_DICT[args.model][args.load_key] = args.load_key
196 |
197 |     # override the training-time settings with the current command-line arguments
198 | load_args_path = osp.join(
199 | config.MODEL_DIR_DICT[args.model][args.load_key], "best_args.npy"
200 | )
201 | load_args = np.load(load_args_path, allow_pickle=True).item()["args"]
202 | load_args_dic = vars(load_args)
203 | args_dic = vars(args)
204 | for key in args_dic:
205 | load_args_dic[key] = args_dic[key]
206 |     args = argparse.Namespace(**load_args_dic)  # merge the two sets of arguments
207 |
208 | print("====== set save dir =======")
209 | now = datetime.datetime.now()
210 | save_dir = os.path.join(args.save_root, args.dataset, args.model, args.save_name)
211 |
212 | if not os.path.exists(save_dir):
213 | os.makedirs(save_dir)
214 | print("{}\n".format(save_dir))
215 |
216 | logging_path = osp.join(save_dir, "info.log")
217 | logging.basicConfig(
218 | level=logging.INFO,
219 | format="%(levelname)s: %(message)s",
220 | filename=logging_path,
221 | filemode="w",
222 | )
223 |
224 | logging.info("====== load info and config =======")
225 | dataset_info = np.load(
226 | osp.join(config.PATH_TO_TEST_LST, args.dataset + ".npy"), allow_pickle=True
227 | ).item()
228 |
229 | for feat_name in ["video", "audio", "text"]:
230 | logging.info(
231 | "Input feature: {} ===> dim is (1, {})".format(
232 | dataset_info["feat_dim"][f"{feat_name}_feat_description"],
233 | dataset_info["feat_dim"][f"{feat_name}_dim"],
234 | )
235 | )
236 |
237 | emo_rule = config.EMO_RULE[args.emo_rule]
238 | emotions = list(emo_rule)
239 |
240 | logging.info("====== load dataset =======")
241 | test_info = dataset_info["test_info"]
242 | test_loader = get_dataloader_valid(
243 | names=test_info["names"],
244 | emo_labels=test_info["emos"],
245 | val_labels=test_info["vals"],
246 | emo_rule=emo_rule,
247 | audio_feat_paths=test_info["audio_feat"],
248 | video_feat_paths=test_info["video_feat"],
249 | text_feat_paths=test_info["text_feat"],
250 | batch_size=args.batch_size,
251 | num_workers=args.num_workers,
252 | shuffle=False,
253 | )
254 |
255 | logging.info("====== build model =======")
256 | model = get_models(args).cuda()
257 | assert args.load_key is not None
258 | model_path = osp.join(config.MODEL_DIR_DICT[args.model][args.load_key], "best.pth")
259 | model.load_state_dict(
260 | torch.load(model_path, map_location=f"cuda:0", weights_only=True)
261 | )
262 |
263 | logging.info("load model: {} \n".format(args.model))
264 | logging.info("load model weight: {} \n".format(model_path))
265 |
266 | logging.info("====== Evaluation =======")
267 |     best_eval_metric = None  # for selecting the best model weight
268 | record_test = {"emoacc": [], "emofscore": []}
269 |
270 | test_results = test_model(
271 | args,
272 | model,
273 | test_loader,
274 | )
275 |
276 | # np.save(osp.join(save_dir, "results.npy"), test_results)
277 |
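278 | # Example invocation (the load_key below is a placeholder for a training run directory
279 | # that contains best.pth and best_args.npy):
280 | #   python test.py --model attention --dataset MER24-test_3A_whisper-base-UTT \
281 | #       --load_key /path/to/saved_run_dir --gpu 0
282 | # Training-time arguments are reloaded from best_args.npy and overridden by the flags above.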
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/feature_extraction/visual/extract_vision_huggingface.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 | import os
3 | import cv2
4 | import math
5 | import argparse
6 | import numpy as np
7 | from PIL import Image
8 |
9 | import torch
10 | import timm # pip install timm==0.9.7
11 | from transformers import AutoModel, AutoFeatureExtractor, AutoImageProcessor
12 | import time
13 |
14 | # import config
15 | import sys
16 |
17 | current_file_path = os.path.abspath(__file__)
18 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path)))
19 |
20 | import config
21 |
22 | ##################### Pretrained models #####################
23 | CLIP_VIT_BASE = (
24 | "clip-vit-base-patch32" # https://huggingface.co/openai/clip-vit-base-patch32
25 | )
26 | CLIP_VIT_LARGE = (
27 | "clip-vit-large-patch14" # https://huggingface.co/openai/clip-vit-large-patch14
28 | )
29 | EVACLIP_VIT = "eva02_base_patch14_224.mim_in22k" # https://huggingface.co/timm/eva02_base_patch14_224.mim_in22k
30 | DATA2VEC_VISUAL = "data2vec-vision-base-ft1k" # https://huggingface.co/facebook/data2vec-vision-base-ft1k
31 | VIDEOMAE_BASE = "videomae-base" # https://huggingface.co/MCG-NJU/videomae-base
32 | VIDEOMAE_LARGE = "videomae-large" # https://huggingface.co/MCG-NJU/videomae-large
33 | DINO2_LARGE = "dinov2-large" # https://huggingface.co/facebook/dinov2-large
34 | DINO2_GIANT = "dinov2-giant" # https://huggingface.co/facebook/dinov2-giant
35 |
36 |
37 | def func_opencv_to_image(img):
38 | img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
39 | return img
40 |
41 |
42 | def func_opencv_to_numpy(img):
43 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
44 | return img
45 |
46 |
47 | def func_read_frames(face_dir, vid):
48 | npy_path = os.path.join(face_dir, vid, f"{vid}.npy")
49 | assert os.path.exists(npy_path), f"Error: {vid} does not have frames.npy!"
50 | frames = np.load(npy_path)
51 | return frames
52 |
53 |
54 | # Strategy 2: revised sampling for VideoMAE, using the sampling code provided by sunlicai
55 | def func_get_indexes_for_videomae(total_frames):
56 |
57 | clip_len = 16
58 | test_num_segment = 2
59 | frame_sample_rate = 4
60 |
61 | # sampling according to 'frame_sample_rate'
62 | all_index = [x for x in range(0, total_frames, frame_sample_rate)]
63 | while len(all_index) < clip_len:
64 | all_index.append(all_index[-1])
65 |
66 |     # get the start/end index according to 'segment_id' => guarantees at least 16 frames follow the start
67 | whole_segment_index = []
68 | for segment_id in range(test_num_segment):
69 | temporal_step = max(
70 | 1.0 * (len(all_index) - clip_len) / (test_num_segment - 1), 0
71 | )
72 | temporal_start = int(segment_id * temporal_step)
73 | segment_index = all_index[temporal_start : temporal_start + clip_len]
74 | print(f"segment '{segment_id+1}' index: {segment_index}")
75 | whole_segment_index.append(segment_index)
76 | return whole_segment_index
77 |
78 |
79 | def resample_frames_sunlicai(frames):
80 | batches = []
81 | whole_segment_index = func_get_indexes_for_videomae(len(frames))
82 | for segment_index in whole_segment_index:
83 | assert len(segment_index) == 16
84 | batches.append(np.array(frames)[segment_index])
85 | return batches
86 |
87 |
88 | # Strategy 3: samples more uniformly than the above [replace the videomae sampling and re-test]
89 | def resample_frames_uniform(frames, nframe=16):
90 | vlen = len(frames)
91 | start, end = 0, vlen
92 |
93 | n_frms_update = min(nframe, vlen) # for vlen < n_frms, only read vlen
94 | indices = np.arange(start, end, vlen / n_frms_update).astype(int).tolist()
95 |
96 | # whether compress into 'n_frms'
97 | while len(indices) < nframe:
98 | indices.append(indices[-1])
99 | indices = indices[:nframe]
100 | assert len(indices) == nframe, f"{indices}, {vlen}, {nframe}"
101 | return frames[indices]
102 |
103 |
104 | def split_into_batch(inputs, bsize=32):
105 | batches = []
106 | for ii in range(math.ceil(len(inputs) / bsize)):
107 | batch = inputs[ii * bsize : (ii + 1) * bsize]
108 | batches.append(batch)
109 | return batches
110 |
111 |
112 | if __name__ == "__main__":
113 | parser = argparse.ArgumentParser(description="Run.")
114 | parser.add_argument(
115 | "--dataset",
116 | type=str,
117 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/",
118 | help="input dataset",
119 | )
120 | parser.add_argument(
121 | "--model_name",
122 | type=str,
123 | default="clip-vit-large-patch14",
124 | help="name of pretrained model",
125 | )
126 | parser.add_argument(
127 | "--feature_level",
128 | type=str,
129 | default="UTTERANCE",
130 | help="feature level [FRAME or UTTERANCE]",
131 | )
132 | parser.add_argument(
133 | "--videomae_type",
134 | type=str,
135 | default=None,
136 | help="videomae input type: [None or sunlicai]",
137 | )
138 | parser.add_argument("--gpu", type=int, default=0, help="gpu id")
139 | params = parser.parse_args()
140 |
141 | print(f"==> Extracting {params.model_name} embeddings...")
142 | model_name = params.model_name.split(".")[0]
143 | face_dir = os.path.join(params.dataset, "features", "openface_face")
144 |
145 | start_time = time.time()
146 |
147 | # gain save_dir
148 | if params.videomae_type is None:
149 | save_dir = os.path.join(
150 | params.dataset,
151 | "features",
152 | f"{model_name}-{params.feature_level[:3]}",
153 | )
154 | else:
155 | assert params.videomae_type == "sunlicai"
156 | save_dir = os.path.join(
157 | params.dataset,
158 | "features",
159 | f"sunlicai-{model_name}-{params.feature_level[:3]}",
160 | )
161 | if not os.path.exists(save_dir):
162 | os.makedirs(save_dir)
163 |
164 | # load model
165 | if params.model_name in [
166 | CLIP_VIT_BASE,
167 | CLIP_VIT_LARGE,
168 | DATA2VEC_VISUAL,
169 | VIDEOMAE_BASE,
170 | VIDEOMAE_LARGE,
171 | ]: # from huggingface
172 | model_dir = os.path.join(
173 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{params.model_name}"
174 | )
175 | model = AutoModel.from_pretrained(model_dir)
176 | processor = AutoFeatureExtractor.from_pretrained(model_dir)
177 | elif params.model_name in [DINO2_LARGE, DINO2_GIANT]:
178 | model_dir = os.path.join(
179 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{params.model_name}"
180 | )
181 | model = AutoModel.from_pretrained(model_dir)
182 | processor = AutoImageProcessor.from_pretrained(model_dir)
183 | elif params.model_name in [EVACLIP_VIT]: # from timm
184 | model_path = os.path.join(
185 | config.PATH_TO_PRETRAINED_MODELS,
186 | f"timm/{params.model_name}/model.safetensors",
187 | )
188 | model = timm.create_model(
189 | params.model_name,
190 | pretrained=True,
191 | num_classes=0,
192 | pretrained_cfg_overlay=dict(file=model_path),
193 | )
194 | data_config = timm.data.resolve_model_data_config(model)
195 | transforms = timm.data.create_transform(**data_config, is_training=False)
196 |
197 |     # put the model on cuda only when a gpu is specified
198 | if params.gpu != -1:
199 | torch.cuda.set_device(params.gpu)
200 | model.cuda()
201 | model.eval()
202 |
203 | load_time = time.time()
204 |
205 | # extract embedding video by video
206 | vids = os.listdir(face_dir)
207 | EMBEDDING_DIM = -1
208 | print(f'Find total "{len(vids)}" videos.')
209 | for i, vid in enumerate(vids, 1):
210 | print(f"Processing video '{vid}' ({i}/{len(vids)})...")
211 | # save_file = os.path.join(save_dir, f'{vid}.npy')
212 | # if os.path.exists(save_file): continue
213 |
214 | # forward process [different model has its unique mode, it is hard to unify them as one process]
215 | # => split into batch to reduce memory usage
216 | with torch.no_grad():
217 | frames = func_read_frames(face_dir, vid)
218 | if params.model_name in [CLIP_VIT_BASE, CLIP_VIT_LARGE]:
219 | frames = [func_opencv_to_image(frame) for frame in frames]
220 | inputs = processor(images=frames, return_tensors="pt")["pixel_values"]
221 | if params.gpu != -1:
222 | inputs = inputs.to("cuda")
223 | batches = split_into_batch(inputs, bsize=32)
224 | embeddings = []
225 | for batch in batches:
226 | embeddings.append(model.get_image_features(batch)) # [58, 768]
227 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768]
228 |
229 | elif params.model_name in [DATA2VEC_VISUAL]:
230 | frames = [func_opencv_to_image(frame) for frame in frames]
231 | inputs = processor(images=frames, return_tensors="pt")[
232 | "pixel_values"
233 | ] # [nframe, 3, 224, 224]
234 | if params.gpu != -1:
235 | inputs = inputs.to("cuda")
236 | batches = split_into_batch(inputs, bsize=32)
237 | embeddings = []
238 | for batch in batches: # [32, 3, 224, 224]
239 | hidden_states = model(
240 | batch, output_hidden_states=True
241 | ).hidden_states # [58, 196 patch + 1 cls, feat=768]
242 | embeddings.append(
243 | torch.stack(hidden_states)[-1].sum(dim=1)
244 | ) # [58, 768]
245 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768]
246 |
247 | elif params.model_name in [DINO2_LARGE, DINO2_GIANT]:
248 | frames = resample_frames_uniform(
249 | frames, nframe=64
250 |                 )  # speed up feature extraction: this samples 64 frames more uniformly
251 | frames = [func_opencv_to_image(frame) for frame in frames]
252 | inputs = processor(images=frames, return_tensors="pt")[
253 | "pixel_values"
254 | ] # [nframe, 3, 224, 224]
255 | if params.gpu != -1:
256 | inputs = inputs.to("cuda")
257 | batches = split_into_batch(inputs, bsize=32)
258 | embeddings = []
259 | for batch in batches: # [32, 3, 224, 224]
260 | hidden_states = model(
261 | batch, output_hidden_states=True
262 | ).hidden_states # [58, 196 patch + 1 cls, feat=768]
263 | embeddings.append(
264 | torch.stack(hidden_states)[-1].sum(dim=1)
265 | ) # [58, 768]
266 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768]
267 |
268 | elif params.model_name in [VIDEOMAE_BASE, VIDEOMAE_LARGE]:
269 |                 # VideoMAE: only supports 16-frame inputs
270 | if params.videomae_type == "sunlicai":
271 | batches = resample_frames_sunlicai(frames)
272 | else:
273 | batches = [
274 | resample_frames_uniform(frames)
275 | ] # convert to list of batches
276 | embeddings = []
277 | for batch in batches:
278 | frames = [
279 | func_opencv_to_numpy(frame) for frame in batch
280 | ] # 16 * [112, 112, 3]
281 | inputs = processor(list(frames), return_tensors="pt")[
282 | "pixel_values"
283 | ] # [1, 16, 3, 224, 224]
284 | if params.gpu != -1:
285 | inputs = inputs.to("cuda")
286 | outputs = model(inputs).last_hidden_state # [1, 1586, 768]
287 | num_patches_per_frame = (
288 | model.config.image_size // model.config.patch_size
289 | ) ** 2 # 14*14
290 | outputs = outputs.view(
291 | 16 // model.config.tubelet_size, num_patches_per_frame, -1
292 | ) # [seg_number, patch, featdim]
293 | embeddings.append(outputs.mean(dim=1)) # [seg_number, featdim]
294 | embeddings = torch.cat(embeddings, axis=0)
295 |
296 | elif params.model_name in [EVACLIP_VIT]:
297 | frames = [func_opencv_to_image(frame) for frame in frames]
298 | inputs = torch.stack(
299 | [transforms(frame) for frame in frames]
300 | ) # [117, 3, 224, 224]
301 | if params.gpu != -1:
302 | inputs = inputs.to("cuda")
303 | batches = split_into_batch(inputs, bsize=32)
304 | embeddings = []
305 | for batch in batches:
306 | embeddings.append(model(batch)) # [58, 768]
307 | embeddings = torch.cat(embeddings, axis=0) # [frames_num, 768]
308 |
309 | embeddings = embeddings.detach().squeeze().cpu().numpy()
310 | EMBEDDING_DIM = max(EMBEDDING_DIM, np.shape(embeddings)[-1])
311 |
312 | # save into npy
313 | save_file = os.path.join(save_dir, f"{vid}.npy")
314 | if params.feature_level == "FRAME":
315 | embeddings = np.array(embeddings).squeeze()
316 | if len(embeddings) == 0:
317 | embeddings = np.zeros((1, EMBEDDING_DIM))
318 | elif len(embeddings.shape) == 1:
319 | embeddings = embeddings[np.newaxis, :]
320 | np.save(save_file, embeddings)
321 | else:
322 | embeddings = np.array(embeddings).squeeze()
323 | if len(embeddings) == 0:
324 | embeddings = np.zeros((EMBEDDING_DIM,))
325 | elif len(embeddings.shape) == 2:
326 | embeddings = np.mean(embeddings, axis=0)
327 | np.save(save_file, embeddings)
328 |
329 | end_time = time.time()
330 | print(f"Model load time used: {load_time - start_time:.1f}s.")
331 | print(f"Total time used: {end_time - start_time:.1f}s.")
332 |
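333 | # Example invocation (the dataset path below is a placeholder; expects faces already
334 | # extracted into <dataset>/features/openface_face by extract_openface.py):
335 | #   python extract_vision_huggingface.py --model_name clip-vit-large-patch14 \
336 | #       --feature_level UTTERANCE --dataset /path/to/mer-dataset-process/
337 | # Frame embeddings are mean-pooled and saved to <dataset>/features/clip-vit-large-patch14-UTT/<vid>.npy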
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import ast
3 | import datetime
4 | import logging
5 | import os
6 | import os.path as osp
7 | import random
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | import torch.optim as optim
13 | from omegaconf import OmegaConf
14 | from tqdm import tqdm
15 |
16 | import config
17 | from dataloader.dataloader import get_dataloader_train, get_dataloader_valid
18 | from models import get_models
19 | from toolkit.utils.draw_process import draw_loss, draw_metric
20 | from toolkit.utils.eval import calculate_results
21 | from toolkit.utils.functions import func_update_storage, merge_args_config
22 | from toolkit.utils.loss import CELoss, MSELoss
23 | from toolkit.utils.metric import gain_metric_from_results
24 |
25 |
26 | def set_seed(seed=0):
27 | torch.manual_seed(seed)
28 | torch.cuda.manual_seed(seed)
29 | torch.cuda.manual_seed_all(seed)
30 | np.random.seed(seed)
31 | random.seed(seed)
32 | torch.backends.cudnn.benchmark = False
33 | torch.backends.cudnn.deterministic = True
34 |
35 |
36 | def save_model(model, save_path):
37 | if hasattr(model, "module"):
38 | model_state = model.module.state_dict()
39 | else:
40 | model_state = model.state_dict()
41 |
42 | torch.save(
43 | model_state,
44 | osp.join(save_path),
45 | )
46 | return
47 |
48 |
49 | def train_or_eval_model(
50 | args, model, reg_loss, cls_loss, dataloader, epoch, optimizer=None, train=False
51 | ):
52 |
53 | if not train:
54 | vidnames = []
55 | val_preds, val_labels = [], []
56 | emo_probs, emo_labels = [], []
57 | losses = []
58 |
59 |     assert not train or optimizer is not None
60 |
61 | if train:
62 | model.train()
63 | else:
64 | model.eval()
65 |
66 | if train:
67 | pbar = tqdm(range(len(dataloader)), desc=f"epoch:{epoch+1}, train")
68 | else:
69 | pbar = tqdm(range(len(dataloader)), desc=f"epoch:{epoch+1}, valid")
70 |
71 | for iter, data in enumerate(dataloader):
72 | if train:
73 | optimizer.zero_grad()
74 |
75 | # read data + cuda
76 | inputs, emos, vals, bnames = data
77 |
78 | if not train:
79 | vidnames += bnames
80 |
81 | for k in inputs.keys():
82 | inputs[k] = inputs[k].cuda()
83 |
84 | emos = emos.long().cuda()
85 | vals = vals.float().cuda()
86 |
87 | # forward process
88 | # start_time = time.time()
89 |
90 | if args.train_input_mode == "input_gt":
91 | features, emos_out, vals_out, interloss = model([inputs, emos])
92 | elif args.train_input_mode == "input":
93 | features, emos_out, vals_out, interloss = model(inputs)
94 | # duration = time.time() - start_time
95 | # macs, params = profile(model, inputs=(batch, ))
96 | # print(f"MACs: {macs}, Parameters: {params}, Duration: {duration}; bsize: {len(bnames)}")
97 |
98 | # loss calculation
99 | loss = interloss
100 |
101 | if args.output_dim1 != 0:
102 | loss = loss + cls_loss(emos_out, emos)
103 | if not train:
104 | emo_probs.append(emos_out.data.cpu().numpy())
105 | emo_labels.append(emos.data.cpu().numpy())
106 | if args.output_dim2 != 0:
107 | loss = loss + reg_loss(vals_out, vals)
108 | if not train:
109 | val_preds.append(vals_out.data.cpu().numpy())
110 | val_labels.append(vals.data.cpu().numpy())
111 | losses.append(loss.data.cpu().numpy())
112 |
113 | pbar.set_postfix(**{"loss": loss.item()})
114 | pbar.update(1)
115 |
116 | # optimize params
117 | if train:
118 | loss.backward()
119 | if model.model.grad_clip != -1:
120 | torch.nn.utils.clip_grad_value_(
121 | [param for param in model.parameters() if param.requires_grad],
122 | model.model.grad_clip,
123 | )
124 | optimizer.step()
125 |
126 | pbar.close()
127 |
128 | if not train:
129 | if emo_probs != []:
130 | emo_probs = np.concatenate(emo_probs)
131 | if emo_labels != []:
132 | emo_labels = np.concatenate(emo_labels)
133 | if val_preds != []:
134 | val_preds = np.concatenate(val_preds)
135 | if val_labels != []:
136 | val_labels = np.concatenate(val_labels)
137 | results, _ = calculate_results(emo_probs, emo_labels, val_preds, val_labels)
138 | save_results = dict(
139 | names=vidnames,
140 | loss=np.mean(losses),
141 | **results,
142 | )
143 | else:
144 | save_results = dict(
145 | loss=np.mean(losses),
146 | )
147 | return save_results
148 |
149 |
150 | if __name__ == "__main__":
151 | parser = argparse.ArgumentParser()
152 | parser.add_argument(
153 | "--seed",
154 | type=int,
155 | default=0,
156 | help="random seed",
157 | )
158 |
159 | # --- data config
160 | parser.add_argument(
161 | "--dataset",
162 | type=str,
163 | default="seed1",
164 | help="dataset info name",
165 | )
166 | parser.add_argument(
167 | "--emo_rule",
168 | type=str,
169 | default="MER",
170 | help="emo map function from emotion to index",
171 | )
172 |
173 | # --- save config
174 | parser.add_argument(
175 | "--save_root",
176 | type=str,
177 | default="./saved",
178 | help="save prediction results and models",
179 | )
180 | parser.add_argument(
181 | "--save_as_time",
182 | action="store_true",
183 | default=False,
184 | help="save suffix as time, default: run",
185 | )
186 | parser.add_argument(
187 | "--save_model",
188 | action="store_true",
189 | default=False,
190 | help="whether to save model, default: False",
191 | )
192 |
193 | # --- model config
194 | parser.add_argument(
195 | "--model",
196 | type=str,
197 | default="auto_attention",
198 | help="model name for training ",
199 | )
200 | parser.add_argument(
201 | "--model_pretrain",
202 | type=str,
203 | default=None,
204 | help="pretrained model path",
205 | )
206 |
207 | # --- feat config
208 | parser.add_argument(
209 | "--feat",
210 | type=str,
211 | default='["senet50face_UTT","senet50face_UTT","senet50face_UTT"]',
212 | help="use feat",
213 | )
214 |
215 | # --- train config
216 | parser.add_argument(
217 | "--lr", type=float, default=1e-4, metavar="lr", help="set lr rate"
218 | )
219 | # parser.add_argument(
220 | # "--lr_end", type=float, default=1e-5, metavar="lr", help="set lr rate"
221 | # )
222 | parser.add_argument(
223 | "--l2",
224 | type=float,
225 | default=0.00001,
226 | metavar="L2",
227 | help="L2 regularization weight",
228 | )
229 | parser.add_argument(
230 | "--batch_size",
231 | type=int,
232 | default=512,
233 | metavar="BS",
234 | help="batch size [deal with OOM]",
235 | )
236 | parser.add_argument(
237 | "--num_workers", type=int, default=4, metavar="nw", help="number of workers"
238 | )
239 | parser.add_argument(
240 | "--epochs", type=int, default=60, metavar="E", help="number of epochs"
241 | )
242 | parser.add_argument("--gpu", default=0, type=int, help="GPU id to use")
243 |
244 | args = parser.parse_args()
245 | args.feat = ast.literal_eval(args.feat)
246 |
247 | set_seed(args.seed)
248 |
249 | torch.cuda.set_device(args.gpu)
250 |
251 | print("====== set save dir =======")
252 | now = datetime.datetime.now()
253 | if args.save_as_time:
254 | save_dir = os.path.join(
255 | args.save_root, args.dataset, args.model, now.strftime("%Y-%m-%d-%H-%M-%S")
256 | )
257 | else:
258 | save_dir = os.path.join(args.save_root, args.dataset, args.model, "run")
259 | if not os.path.exists(save_dir):
260 | os.makedirs(save_dir)
261 | print("{}\n".format(save_dir))
262 |
263 | logging_path = osp.join(save_dir, "info.log")
264 | logging.basicConfig(
265 | level=logging.INFO,
266 | format="%(levelname)s: %(message)s",
267 | filename=logging_path,
268 | filemode="w",
269 | )
270 |
271 | logging.info("====== load info and config =======")
272 | dataset_info = np.load(
273 | osp.join(config.PATH_TO_TRAIN_LST, args.dataset + ".npy"), allow_pickle=True
274 | ).item()
275 |
276 | model_config = OmegaConf.load("models/model-tune.yaml")[args.model]
277 | model_config = OmegaConf.to_container(model_config, resolve=True)
278 |
279 | emo_rule = config.EMO_RULE[args.emo_rule]
280 | model_config["emo_rule"] = emo_rule
281 | model_config["output_dim1"] = len(emo_rule)
282 | model_config["output_dim2"] = 0
283 |
284 | feat_dims = []
285 | feat_roots = []
286 | feat_types = []
287 | for use_feat in args.feat:
288 | if use_feat in config.FEAT_VIDEO_DICT:
289 | feat_dims.append(config.FEAT_VIDEO_DICT[use_feat][0])
290 | feat_roots.append(config.FEAT_VIDEO_DICT[use_feat][1])
291 | feat_types.append("V")
292 | elif use_feat in config.FEAT_AUDIO_DICT:
293 | feat_dims.append(config.FEAT_AUDIO_DICT[use_feat][0])
294 | feat_roots.append(config.FEAT_AUDIO_DICT[use_feat][1])
295 | feat_types.append("A")
296 | elif use_feat in config.FEAT_TEXT_DICT:
297 | feat_dims.append(config.FEAT_TEXT_DICT[use_feat][0])
298 | feat_roots.append(config.FEAT_TEXT_DICT[use_feat][1])
299 | feat_types.append("T")
300 | model_config["feat_dims"] = feat_dims
301 | model_config["feat_roots"] = feat_roots
302 | for feat_type, feat_dim, feat_root in zip(feat_types, feat_dims, feat_roots):
303 | logging.info("Modality:{}, Input feature: {}".format(feat_type, feat_root))
304 | logging.info("===> dim is (1, {}) \n".format(feat_dim))
305 |
306 |     args = merge_args_config(args, model_config)  # the two parameter sets overlap; merge them into one namespace
307 | logging.info("save config: {} \n".format(osp.join(save_dir, "args.yaml")))
308 | OmegaConf.save(vars(args), osp.join(save_dir, "args.yaml"))
309 |
310 | logging.info("====== load dataset =======")
311 | train_info = dataset_info["train_info"]
312 | train_loader = get_dataloader_train(
313 | names=train_info["names"],
314 | emo_labels=train_info["emos"],
315 | val_labels=train_info["vals"],
316 | emo_rule=emo_rule,
317 | feat_roots=args.feat_roots,
318 | batch_size=args.batch_size,
319 | num_workers=args.num_workers,
320 | shuffle=True,
321 | )
322 | valid_info = dataset_info["valid_info"]
323 | valid_loader = get_dataloader_valid(
324 | names=valid_info["names"],
325 | emo_labels=valid_info["emos"],
326 | val_labels=valid_info["vals"],
327 | emo_rule=emo_rule,
328 | feat_roots=args.feat_roots,
329 | batch_size=args.batch_size,
330 | num_workers=args.num_workers,
331 | shuffle=False,
332 | )
333 |
334 | logging.info("====== build model =======")
335 | model = get_models(args).cuda()
336 | cls_loss = CELoss().cuda()
337 | reg_loss = MSELoss().cuda()
338 | optimizer = optim.Adam(
339 | (param for param in model.parameters() if param.requires_grad),
340 | lr=args.lr,
341 | weight_decay=args.l2,
342 | )
343 | logging.info("load model: {} \n".format(args.model))
344 |
345 | if args.model_pretrain is not None:
346 |         model.load_state_dict(torch.load(args.model_pretrain, map_location=f"cuda:{args.gpu}"))
347 |
348 | logging.info("====== Training and Evaluation =======")
349 |     best_eval_metric = None  # for selecting the best model weight
350 | record_train = {"epoch": [], "loss": []}
351 | record_valid = {"epoch": [], "loss": [], "emoacc": [], "emofscore": []}
352 | best_valid = {}
353 | for epoch in range(args.epochs):
354 | logging.info(
355 | "epoch: {}, lr:{:.8f}".format(
356 | epoch + 1, optimizer.state_dict()["param_groups"][0]["lr"]
357 | )
358 | )
359 |
360 | train_results = train_or_eval_model(
361 | args,
362 | model,
363 | reg_loss,
364 | cls_loss,
365 | train_loader,
366 | epoch=epoch,
367 | optimizer=optimizer,
368 | train=True,
369 | )
370 | valid_results = train_or_eval_model(
371 | args,
372 | model,
373 | reg_loss,
374 | cls_loss,
375 | valid_loader,
376 | epoch=epoch,
377 | optimizer=None,
378 | train=False,
379 | )
380 |
381 | # --- logging info
382 | logging_info = "\tTRAIN: loss:{:.4f}".format(train_results["loss"])
383 | logging.info(logging_info)
384 |
385 | logging_info = "\tVALID: loss:{:.4f}, acc:{:.4f}, f-score:{:.4f}".format(
386 | valid_results["loss"],
387 | valid_results["emoacc"],
388 | valid_results["emofscore"],
389 | )
390 | logging.info(logging_info)
391 |
392 | # --- record data and plot
393 | record_train["epoch"].append(epoch + 1)
394 | record_train["loss"].append(train_results["loss"])
395 |
396 | record_valid["epoch"].append(epoch + 1)
397 | for key in record_valid:
398 | if key == "epoch":
399 | continue
400 | record_valid[key].append(valid_results[key])
401 |
402 | # --- save best model
403 | ## select metric
404 | eval_metric = valid_results["emofscore"]
405 |
406 | # save eval results, model config and weight
407 | if best_eval_metric is None or best_eval_metric < eval_metric:
408 | best_eval_metric = eval_metric
409 | # save best result info in valid dataset
410 | best_valid["epoch"] = epoch + 1
411 | for key in record_valid:
412 | if key in ["epoch"]:
413 | continue
414 | best_valid[key] = record_valid[key][-1]
415 |
416 | # save best result in valid dataset
417 | epoch_store = {}
418 | func_update_storage(
419 | inputs=valid_results, prefix="eval", outputs=epoch_store
420 | )
421 | np.save(osp.join(save_dir, "best_valid_results.npy"), epoch_store)
422 |
423 | # save best model weight and args
424 | np.save(osp.join(save_dir, "best_args.npy"), {"args": args})
425 | if args.save_model:
426 | save_path = f"{save_dir}/best.pth"
427 | save_model(model, save_path)
428 |
429 | logging.info("\t*** Update best info ! ***")
430 |
431 | draw_loss(
432 | record_train["epoch"],
433 | record_train["loss"],
434 | record_valid["loss"],
435 | osp.join(save_dir, "loss.png"),
436 | )
437 | for key in record_valid:
438 | if key in ["epoch", "loss"]:
439 | continue
440 | draw_metric(
441 | record_valid["epoch"],
442 | record_valid[key],
443 | key,
444 | osp.join(save_dir, "{}.png".format(key)),
445 | )
446 |
447 | logging.info("End Training ! \n\n")
448 | logging_info = "BEST: epoch:{:.0f}, loss:{:.4f}, acc:{:.4f}, f-score:{:.4f}".format(
449 | best_valid["epoch"],
450 | best_valid["loss"],
451 | best_valid["emoacc"],
452 | best_valid["emofscore"],
453 | )
454 | logging.info(logging_info)
455 |
456 | # clear memory
457 | del model
458 | del optimizer
459 | torch.cuda.empty_cache()
460 |
--------------------------------------------------------------------------------
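Before building the dataloaders, train.py resolves every entry of --feat against the feature dictionaries in config.py to obtain its dimension, on-disk root, and modality tag. A self-contained sketch of that lookup (not a file in this repository; the table contents and the audio/text feature names are hypothetical placeholders for config.FEAT_VIDEO_DICT / FEAT_AUDIO_DICT / FEAT_TEXT_DICT):

import ast

# hypothetical placeholder tables; the real ones live in config.py
FEAT_VIDEO_DICT = {"senet50face_UTT": (512, "/path/to/features/senet50face-UTT")}
FEAT_AUDIO_DICT = {"hubert-large_UTT": (1024, "/path/to/features/hubert-large-UTT")}
FEAT_TEXT_DICT = {"roberta-large_UTT": (1024, "/path/to/features/roberta-large-UTT")}

feat_arg = '["senet50face_UTT", "hubert-large_UTT", "roberta-large_UTT"]'
feat_dims, feat_roots, feat_types = [], [], []
for use_feat in ast.literal_eval(feat_arg):
    for tag, table in [("V", FEAT_VIDEO_DICT), ("A", FEAT_AUDIO_DICT), ("T", FEAT_TEXT_DICT)]:
        if use_feat in table:
            dim, root = table[use_feat]
            feat_dims.append(dim)
            feat_roots.append(root)
            feat_types.append(tag)

print(feat_types, feat_dims)  # ['V', 'A', 'T'] [512, 1024, 1024]

Each selected feature contributes one (modality, dim, root) triple, which train.py logs and forwards to the dataloaders and the model config.
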
/models/emo_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Encoder(nn.Module):
6 |
7 | def __init__(
8 | self, inp_dim, out_dim, dropout, mid_dim=None, num_heads=4, num_layers=2
9 | ):
10 |
11 | super(Encoder, self).__init__()
12 | if isinstance(inp_dim, list):
13 | self.mode = len(inp_dim)
14 | else:
15 | self.mode = 1
16 |
17 | if mid_dim is None:
18 | mid_dim = out_dim
19 | if self.mode == 1:
20 | # self.norm = nn.BatchNorm1d(in_size)
21 | self.linear = nn.Linear(inp_dim, mid_dim)
22 | self.net = nn.TransformerEncoder(
23 | nn.TransformerEncoderLayer(
24 | d_model=mid_dim,
25 | nhead=num_heads,
26 | dim_feedforward=2048,
27 | dropout=dropout,
28 | batch_first=True,
29 | ),
30 | num_layers=num_layers,
31 | )
32 | self.linear2 = nn.Linear(mid_dim, out_dim)
33 | elif self.mode == 2:
34 |             self.linear = nn.Linear(35, mid_dim)  # 35 = number of OpenFace AU columns (intensity + presence)
35 | self.cls_token = nn.Parameter(torch.zeros(1, 1, mid_dim))
36 | self.net = nn.TransformerEncoder(
37 | nn.TransformerEncoderLayer(
38 | d_model=mid_dim,
39 | nhead=num_heads,
40 | dim_feedforward=2048,
41 | dropout=dropout,
42 | batch_first=True,
43 | ),
44 | num_layers=num_layers,
45 | )
46 | self.linear2 = nn.Linear(mid_dim, out_dim)
47 |
48 | def forward(self, x):
49 | """
50 | Args:
51 |             x: (batch_size, inp_dim) for mode 1, or (batch_size, seq_len, inp_dim) for mode 2 (frame-level OpenFace features)
52 | """
53 | # normed = self.norm(x)
54 |
55 | if self.mode == 1:
56 | x = self.linear(x)
57 | x = self.net(x)
58 | elif self.mode == 2:
59 |             x = x[:, :, 674:]  # keep only the AU columns of the OpenFace features
60 | x = self.linear(x)
61 | B, L, C = x.shape
62 | cls_token = self.cls_token.expand(B, -1, -1) # (B, 1, mid_dim)
63 | x = torch.cat((cls_token, x), dim=1) # (B, L+1, mid_dim)
64 | x = self.net(x) # (B, L+1, mid_dim)
65 | x = x[:, 0, :]
66 |
67 | x = self.linear2(x)
68 |
69 | return x
70 |
71 |
72 | class FcClassifier(nn.Module):
73 | def __init__(self, hidden_dim, cls_layers, output_dim=6, dropout=0, use_bn=True):
74 | super(FcClassifier, self).__init__()
75 | self.fc_layers = nn.ModuleList()
76 | self.fc_layers.append(nn.Linear(hidden_dim, cls_layers[0]))
77 | if use_bn is True:
78 | self.fc_layers.append(nn.BatchNorm1d(cls_layers[0]))
79 | self.fc_layers.append(nn.ReLU(inplace=True))
80 | if dropout > 0:
81 | self.fc_layers.append(nn.Dropout(dropout))
82 |
83 | for i in range(1, len(cls_layers)):
84 | self.fc_layers.append(nn.Linear(cls_layers[i - 1], cls_layers[i]))
85 | if use_bn is True:
86 | self.fc_layers.append(nn.BatchNorm1d(cls_layers[i]))
87 | self.fc_layers.append(nn.ReLU(inplace=True))
88 | if dropout > 0:
89 | self.fc_layers.append(nn.Dropout(dropout))
90 |
91 | self.output_layer = nn.Linear(cls_layers[-1], output_dim)
92 |
93 | def forward(self, x):
94 | for layer in self.fc_layers:
95 | x = layer(x)
96 | x = self.output_layer(x)
97 | return x
98 |
99 |
100 | class Auto_WeightV1(nn.Module):
101 | def __init__(self, args):
102 | super(Auto_WeightV1, self).__init__()
103 | feat_dims = args.feat_dims
104 | output_dim1 = args.output_dim1
105 | output_dim2 = args.output_dim2
106 | dropout = args.dropout
107 | hidden_dim = args.hidden_dim
108 | self.grad_clip = args.grad_clip
109 | self.feat_dims = feat_dims
110 |
111 | num_heads = args.num_heads
112 | num_layers = args.num_layers
113 |
114 | assert args.feat_type == "utt"
115 |
116 |         assert len(feat_dims) <= 15  # only 15 encoder/classifier pairs are defined below
117 | if len(feat_dims) >= 1:
118 | self.encoder0 = Encoder(
119 | feat_dims[0],
120 | hidden_dim,
121 | dropout,
122 | num_heads=num_heads,
123 | num_layers=num_layers,
124 | )
125 | self.cls0 = FcClassifier(
126 | hidden_dim,
127 | [256, 128],
128 | output_dim=output_dim1,
129 | dropout=0,
130 | use_bn=True,
131 | )
132 | if len(feat_dims) >= 2:
133 | self.encoder1 = Encoder(
134 | feat_dims[1],
135 | hidden_dim,
136 | dropout,
137 | num_heads=num_heads,
138 | num_layers=num_layers,
139 | )
140 | self.cls1 = FcClassifier(
141 | hidden_dim,
142 | [256, 128],
143 | output_dim=output_dim1,
144 | dropout=0,
145 | use_bn=True,
146 | )
147 | if len(feat_dims) >= 3:
148 | self.encoder2 = Encoder(
149 | feat_dims[2],
150 | hidden_dim,
151 | dropout,
152 | num_heads=num_heads,
153 | num_layers=num_layers,
154 | )
155 | self.cls2 = FcClassifier(
156 | hidden_dim,
157 | [256, 128],
158 | output_dim=output_dim1,
159 | dropout=0,
160 | use_bn=True,
161 | )
162 | if len(feat_dims) >= 4:
163 | self.encoder3 = Encoder(
164 | feat_dims[3],
165 | hidden_dim,
166 | dropout,
167 | num_heads=num_heads,
168 | num_layers=num_layers,
169 | )
170 | self.cls3 = FcClassifier(
171 | hidden_dim,
172 | [256, 128],
173 | output_dim=output_dim1,
174 | dropout=0,
175 | use_bn=True,
176 | )
177 | if len(feat_dims) >= 5:
178 | self.encoder4 = Encoder(
179 | feat_dims[4],
180 | hidden_dim,
181 | dropout,
182 | num_heads=num_heads,
183 | num_layers=num_layers,
184 | )
185 | self.cls4 = FcClassifier(
186 | hidden_dim,
187 | [256, 128],
188 | output_dim=output_dim1,
189 | dropout=0,
190 | use_bn=True,
191 | )
192 | if len(feat_dims) >= 6:
193 | self.encoder5 = Encoder(
194 | feat_dims[5],
195 | hidden_dim,
196 | dropout,
197 | num_heads=num_heads,
198 | num_layers=num_layers,
199 | )
200 | self.cls5 = FcClassifier(
201 | hidden_dim,
202 | [256, 128],
203 | output_dim=output_dim1,
204 | dropout=0,
205 | use_bn=True,
206 | )
207 | if len(feat_dims) >= 7:
208 | self.encoder6 = Encoder(
209 | feat_dims[6],
210 | hidden_dim,
211 | dropout,
212 | num_heads=num_heads,
213 | num_layers=num_layers,
214 | )
215 | self.cls6 = FcClassifier(
216 | hidden_dim,
217 | [256, 128],
218 | output_dim=output_dim1,
219 | dropout=0,
220 | use_bn=True,
221 | )
222 | if len(feat_dims) >= 8:
223 | self.encoder7 = Encoder(
224 | feat_dims[7],
225 | hidden_dim,
226 | dropout,
227 | num_heads=num_heads,
228 | num_layers=num_layers,
229 | )
230 | self.cls7 = FcClassifier(
231 | hidden_dim,
232 | [256, 128],
233 | output_dim=output_dim1,
234 | dropout=0,
235 | use_bn=True,
236 | )
237 | if len(feat_dims) >= 9:
238 | self.encoder8 = Encoder(
239 | feat_dims[8],
240 | hidden_dim,
241 | dropout,
242 | num_heads=num_heads,
243 | num_layers=num_layers,
244 | )
245 | self.cls8 = FcClassifier(
246 | hidden_dim,
247 | [256, 128],
248 | output_dim=output_dim1,
249 | dropout=0,
250 | use_bn=True,
251 | )
252 | if len(feat_dims) >= 10:
253 | self.encoder9 = Encoder(
254 | feat_dims[9],
255 | hidden_dim,
256 | dropout,
257 | num_heads=num_heads,
258 | num_layers=num_layers,
259 | )
260 | self.cls9 = FcClassifier(
261 | hidden_dim,
262 | [256, 128],
263 | output_dim=output_dim1,
264 | dropout=0,
265 | use_bn=True,
266 | )
267 | if len(feat_dims) >= 11:
268 | self.encoder10 = Encoder(
269 | feat_dims[10],
270 | hidden_dim,
271 | dropout,
272 | num_heads=num_heads,
273 | num_layers=num_layers,
274 | )
275 | self.cls10 = FcClassifier(
276 | hidden_dim,
277 | [256, 128],
278 | output_dim=output_dim1,
279 | dropout=0,
280 | use_bn=True,
281 | )
282 | if len(feat_dims) >= 12:
283 | self.encoder11 = Encoder(
284 | feat_dims[11],
285 | hidden_dim,
286 | dropout,
287 | num_heads=num_heads,
288 | num_layers=num_layers,
289 | )
290 | self.cls11 = FcClassifier(
291 | hidden_dim,
292 | [256, 128],
293 | output_dim=output_dim1,
294 | dropout=0,
295 | use_bn=True,
296 | )
297 | if len(feat_dims) >= 13:
298 | self.encoder12 = Encoder(
299 | feat_dims[12],
300 | hidden_dim,
301 | dropout,
302 | num_heads=num_heads,
303 | num_layers=num_layers,
304 | )
305 | self.cls12 = FcClassifier(
306 | hidden_dim,
307 | [256, 128],
308 | output_dim=output_dim1,
309 | dropout=0,
310 | use_bn=True,
311 | )
312 | if len(feat_dims) >= 14:
313 | self.encoder13 = Encoder(
314 | feat_dims[13],
315 | hidden_dim,
316 | dropout,
317 | num_heads=num_heads,
318 | num_layers=num_layers,
319 | )
320 | self.cls13 = FcClassifier(
321 | hidden_dim,
322 | [256, 128],
323 | output_dim=output_dim1,
324 | dropout=0,
325 | use_bn=True,
326 | )
327 | if len(feat_dims) >= 15:
328 | self.encoder14 = Encoder(
329 | feat_dims[14],
330 | hidden_dim,
331 | dropout,
332 | num_heads=num_heads,
333 | num_layers=num_layers,
334 | )
335 | self.cls14 = FcClassifier(
336 | hidden_dim,
337 | [256, 128],
338 | output_dim=output_dim1,
339 | dropout=0,
340 | use_bn=True,
341 | )
342 |
343 | # --- feature attention
344 | self.attention_mlp = Encoder(
345 | hidden_dim * len(feat_dims),
346 | hidden_dim,
347 | dropout,
348 | num_heads=num_heads,
349 | num_layers=num_layers,
350 | )
351 | self.fc_att = nn.Linear(hidden_dim, len(feat_dims))
352 |
353 | # --- logit weight
354 | self.weight_fc = nn.Sequential(
355 | nn.Linear(hidden_dim * len(feat_dims), hidden_dim),
356 | nn.GELU(),
357 | nn.Linear(hidden_dim, len(feat_dims) + 1),
358 | )
359 |
360 | self.fc_out_1 = FcClassifier(
361 | hidden_dim * len(feat_dims),
362 | [256, 128],
363 | output_dim=output_dim1,
364 | dropout=0,
365 | use_bn=True,
366 | )
367 | self.fc_out_2 = FcClassifier(
368 | hidden_dim * len(feat_dims),
369 | [256, 128],
370 | output_dim=output_dim2,
371 | dropout=0,
372 | use_bn=True,
373 | )
374 |
375 | self.criterion = nn.CrossEntropyLoss()
376 |
377 | def forward(self, inp):
378 | """
379 | support feat_type: utt | frm-align | frm-unalign
380 | """
381 |
382 | [batch, emos] = inp
383 |
384 | hiddens = []
385 | logits = []
386 | if len(self.feat_dims) >= 1:
387 | hiddens.append(self.encoder0(batch[f"feat0"]))
388 | logits.append(self.cls0(hiddens[0]))
389 | if len(self.feat_dims) >= 2:
390 | hiddens.append(self.encoder1(batch[f"feat1"]))
391 | logits.append(self.cls1(hiddens[1]))
392 | if len(self.feat_dims) >= 3:
393 | hiddens.append(self.encoder2(batch[f"feat2"]))
394 | logits.append(self.cls2(hiddens[2]))
395 | if len(self.feat_dims) >= 4:
396 | hiddens.append(self.encoder3(batch[f"feat3"]))
397 | logits.append(self.cls3(hiddens[3]))
398 | if len(self.feat_dims) >= 5:
399 | hiddens.append(self.encoder4(batch[f"feat4"]))
400 | logits.append(self.cls4(hiddens[4]))
401 | if len(self.feat_dims) >= 6:
402 | hiddens.append(self.encoder5(batch[f"feat5"]))
403 | logits.append(self.cls5(hiddens[5]))
404 | if len(self.feat_dims) >= 7:
405 | hiddens.append(self.encoder6(batch[f"feat6"]))
406 | logits.append(self.cls6(hiddens[6]))
407 | if len(self.feat_dims) >= 8:
408 | hiddens.append(self.encoder7(batch[f"feat7"]))
409 | logits.append(self.cls7(hiddens[7]))
410 | if len(self.feat_dims) >= 9:
411 | hiddens.append(self.encoder8(batch[f"feat8"]))
412 | logits.append(self.cls8(hiddens[8]))
413 | if len(self.feat_dims) >= 10:
414 | hiddens.append(self.encoder9(batch[f"feat9"]))
415 | logits.append(self.cls9(hiddens[9]))
416 | if len(self.feat_dims) >= 11:
417 | hiddens.append(self.encoder10(batch[f"feat10"]))
418 | logits.append(self.cls10(hiddens[10]))
419 | if len(self.feat_dims) >= 12:
420 | hiddens.append(self.encoder11(batch[f"feat11"]))
421 | logits.append(self.cls11(hiddens[11]))
422 | if len(self.feat_dims) >= 13:
423 | hiddens.append(self.encoder12(batch[f"feat12"]))
424 | logits.append(self.cls12(hiddens[12]))
425 | if len(self.feat_dims) >= 14:
426 | hiddens.append(self.encoder13(batch[f"feat13"]))
427 | logits.append(self.cls13(hiddens[13]))
428 | if len(self.feat_dims) >= 15:
429 | hiddens.append(self.encoder14(batch[f"feat14"]))
430 | logits.append(self.cls14(hiddens[14]))
431 |
432 | multi_hidden1 = torch.cat(hiddens, dim=1) # [32, 384]
433 |
434 | logits.append(self.fc_out_1(multi_hidden1))
435 | vals_out = self.fc_out_2(multi_hidden1)
436 |
437 | weight = self.weight_fc(multi_hidden1).unsqueeze(-1)
438 |
439 | multi_logits = torch.stack(logits, dim=2).permute(0, 2, 1) # [32, N, 6]
440 | emos_out = (weight * multi_logits).sum(dim=1)
441 |
442 | if self.training:
443 | interloss = self.cal_loss(
444 | logits,
445 | emos,
446 | )
447 | else:
448 | interloss = torch.tensor(0).cuda()
449 |
450 | return None, emos_out, vals_out, interloss
451 |
452 | def cal_loss(
453 | self,
454 | logits,
455 | emos,
456 | ):
457 | # feat_A_concat, feat_V_concat, feat_L_concat):
458 | emos = emos.to(logits[0].device)
459 |
460 | loss = 0
461 | for logit in logits:
462 | loss += self.criterion(logit, emos) / len(logits)
463 |
464 | return loss
465 |
--------------------------------------------------------------------------------
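Auto_WeightV1 fuses one logit per modality with one logit computed from the concatenated hidden states, weighted by a sample-dependent vector from weight_fc. A shape-only sketch of that fusion step (not a file in this repository; batch size, hidden size, modality count, and class count are illustrative stand-ins):

import torch
import torch.nn as nn

B, H, N, C = 32, 256, 3, 6                       # batch, hidden_dim, #modalities, #classes
hiddens = [torch.randn(B, H) for _ in range(N)]  # stand-ins for the encoder outputs
logits = [torch.randn(B, C) for _ in range(N)]   # stand-ins for per-modality cls outputs

multi_hidden1 = torch.cat(hiddens, dim=1)        # [B, N*H]
logits.append(torch.randn(B, C))                 # stand-in for fc_out_1(multi_hidden1)

weight_fc = nn.Sequential(nn.Linear(N * H, H), nn.GELU(), nn.Linear(H, N + 1))
weight = weight_fc(multi_hidden1).unsqueeze(-1)  # [B, N+1, 1] per-sample branch weights

multi_logits = torch.stack(logits, dim=2).permute(0, 2, 1)  # [B, N+1, C]
emos_out = (weight * multi_logits).sum(dim=1)               # [B, C] fused emotion logits
print(emos_out.shape)                                       # torch.Size([32, 6])

During training, cal_loss additionally applies cross-entropy to every branch logit, so each modality receives its own supervision through interloss.
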
/feature_extraction/text/extract_text_huggingface.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 | import os
3 | import time
4 | import argparse
5 | import numpy as np
6 | import pandas as pd
7 |
8 | import torch
9 | from transformers import (
10 | AutoModel,
11 | BertTokenizer,
12 | AutoTokenizer,
13 | ) # version: 4.5.1, pip install transformers
14 | from transformers import GPT2Tokenizer, GPT2Model, AutoModelForCausalLM
15 |
16 | # local folder
17 | import sys
18 |
19 | current_file_path = os.path.abspath(__file__)
20 | sys.path.append(os.path.dirname(os.path.dirname(current_file_path)))
21 | import config
22 |
23 | ##################### English #####################
24 | BERT_BASE = "bert-base-cased"
25 | BERT_LARGE = "bert-large-cased"
26 | BERT_BASE_UNCASED = "bert-base-uncased"
27 | BERT_LARGE_UNCASED = "bert-large-uncased"
28 | ALBERT_BASE = "albert-base-v2"
29 | ALBERT_LARGE = "albert-large-v2"
30 | ALBERT_XXLARGE = "albert-xxlarge-v2"
31 | ROBERTA_BASE = "roberta-base"
32 | ROBERTA_LARGE = "roberta-large"
33 | ELECTRA_BASE = "electra-base-discriminator"
34 | ELECTRA_LARGE = "electra-large-discriminator"
35 | XLNET_BASE = "xlnet-base-cased"
36 | XLNET_LARGE = "xlnet-large-cased"
37 | T5_BASE = "t5-base"
38 | T5_LARGE = "t5-large"
39 | DEBERTA_BASE = "deberta-base"
40 | DEBERTA_LARGE = "deberta-large"
41 | DEBERTA_XLARGE = "deberta-v2-xlarge"
42 | DEBERTA_XXLARGE = "deberta-v2-xxlarge"
43 |
44 | ##################### Chinese #####################
45 | BERT_BASE_CHINESE = "bert-base-chinese" # https://huggingface.co/bert-base-chinese
46 | ROBERTA_BASE_CHINESE = (
47 | "chinese-roberta-wwm-ext" # https://huggingface.co/hfl/chinese-roberta-wwm-ext
48 | )
49 | ROBERTA_LARGE_CHINESE = "chinese-roberta-wwm-ext-large" # https://huggingface.co/hfl/chinese-roberta-wwm-ext-large
50 | DEBERTA_LARGE_CHINESE = (
51 | "deberta-chinese-large" # https://huggingface.co/WENGSYX/Deberta-Chinese-Large
52 | )
53 | ELECTRA_SMALL_CHINESE = "chinese-electra-180g-small" # https://huggingface.co/hfl/chinese-electra-180g-small-discriminator
54 | ELECTRA_BASE_CHINESE = "chinese-electra-180g-base" # https://huggingface.co/hfl/chinese-electra-180g-base-discriminator
55 | ELECTRA_LARGE_CHINESE = "chinese-electra-180g-large" # https://huggingface.co/hfl/chinese-electra-180g-large-discriminator
56 | XLNET_BASE_CHINESE = (
57 | "chinese-xlnet-base" # https://huggingface.co/hfl/chinese-xlnet-base
58 | )
59 | MACBERT_BASE_CHINESE = (
60 | "chinese-macbert-base" # https://huggingface.co/hfl/chinese-macbert-base
61 | )
62 | MACBERT_LARGE_CHINESE = (
63 | "chinese-macbert-large" # https://huggingface.co/hfl/chinese-macbert-large
64 | )
65 | PERT_BASE_CHINESE = "chinese-pert-base" # https://huggingface.co/hfl/chinese-pert-base
66 | PERT_LARGE_CHINESE = (
67 | "chinese-pert-large" # https://huggingface.co/hfl/chinese-pert-large
68 | )
69 | LERT_SMALL_CHINESE = (
70 | "chinese-lert-small" # https://huggingface.co/hfl/chinese-lert-small
71 | )
72 | LERT_BASE_CHINESE = "chinese-lert-base" # https://huggingface.co/hfl/chinese-lert-base
73 | LERT_LARGE_CHINESE = (
74 | "chinese-lert-large" # https://huggingface.co/hfl/chinese-lert-large
75 | )
76 | GPT2_CHINESE = "gpt2-chinese-cluecorpussmall" # https://huggingface.co/uer/gpt2-chinese-cluecorpussmall
77 | CLIP_CHINESE = "taiyi-clip-roberta-chinese" # https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese
78 | WENZHONG_GPT2_CHINESE = "wenzhong2-gpt2-chinese" # https://huggingface.co/IDEA-CCNL/Wenzhong2.0-GPT2-3.5B-chinese
79 | ALBERT_TINY_CHINESE = (
80 | "albert_chinese_tiny" # https://huggingface.co/clue/albert_chinese_tiny
81 | )
82 | ALBERT_SMALL_CHINESE = (
83 | "albert_chinese_small" # https://huggingface.co/clue/albert_chinese_small
84 | )
85 | SIMBERT_BASE_CHINESE = (
86 | "simbert-base-chinese" # https://huggingface.co/WangZeJun/simbert-base-chinese
87 | )
88 |
89 | ##################### Multilingual #####################
90 | MPNET_BASE = "paraphrase-multilingual-mpnet-base-v2" # https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2
91 |
92 | ##################### LLM #####################
93 | LLAMA_7B = "llama-7b-hf" # https://huggingface.co/decapoda-research/llama-7b-hf
94 | LLAMA_13B = "llama-13b-hf" # https://huggingface.co/decapoda-research/llama-13b-hf
95 | LLAMA2_7B = "llama-2-7b" # https://huggingface.co/meta-llama/Llama-2-7b
96 | LLAMA2_13B = "Llama-2-13b-hf" # https://huggingface.co/NousResearch/Llama-2-13b-hf
97 | VICUNA_7B = "vicuna-7b-v0" # https://huggingface.co/lmsys/vicuna-7b-delta-v0
98 | VICUNA_13B = (
99 | "stable-vicuna-13b" # https://huggingface.co/CarperAI/stable-vicuna-13b-delta
100 | )
101 | ALPACE_13B = (
102 | "chinese-alpaca-2-13b" # https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
103 | )
104 | MOSS_7B = "moss-base-7b" # https://huggingface.co/fnlp/moss-base-7b
105 | STABLEML_7B = "stablelm-base-alpha-7b-v2" # https://huggingface.co/stabilityai/stablelm-base-alpha-7b-v2
106 | BLOOM_7B = "bloom-7b1" # https://huggingface.co/bigscience/bloom-7b1
107 | CHATGLM2_6B = "chatglm2-6b" # https://huggingface.co/THUDM/chatglm2-6b
108 | # relies on pytorch=2.0 => env: videollama4 + cpu
109 | FALCON_7B = "falcon-7b" # https://huggingface.co/tiiuae/falcon-7b
110 | # Baichuan: pip install transformers_stream_generator
111 | BAICHUAN_7B = "Baichuan-7B" # https://huggingface.co/baichuan-inc/Baichuan-7B
112 | BAICHUAN_13B = (
113 | "Baichuan-13B-Base" # https://huggingface.co/baichuan-inc/Baichuan-13B-Base
114 | )
115 | # BAICHUAN2_7B: conda install xformers -c xformers
116 | BAICHUAN2_7B = (
117 | "Baichuan2-7B-Base" # https://huggingface.co/baichuan-inc/Baichuan2-7B-Base
118 | )
119 | # BAICHUAN2_13B: pip install accelerate
120 | BAICHUAN2_13B = (
121 | "Baichuan2-13B-Base" # https://huggingface.co/baichuan-inc/Baichuan2-13B-Base
122 | )
123 | OPT_13B = "opt-13b" # https://huggingface.co/facebook/opt-13b
124 |
125 |
126 | ################################################################
127 | # automatically locate and drop the features of meaningless (special) tokens
128 | def find_start_end_pos(tokenizer):
129 | sentence = "今天天气真好" # 句子中没有空格
130 | input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"][0]
131 | start, end = None, None
132 |
133 | # find start, must in range [0, 1, 2]
134 | for start in range(0, 3, 1):
135 |         # decode sometimes inserts spaces, so strip them before comparing
136 | outputs = tokenizer.decode(input_ids[start:]).replace(" ", "")
137 | if outputs == sentence:
138 | print(f"start: {start}; end: {end}")
139 | return start, None
140 |
141 | if outputs.startswith(sentence):
142 | break
143 |
144 | # find end, must in range [-1, -2]
145 | for end in range(-1, -3, -1):
146 | outputs = tokenizer.decode(input_ids[start:end]).replace(" ", "")
147 | if outputs == sentence:
148 | break
149 |
150 | assert tokenizer.decode(input_ids[start:end]).replace(" ", "") == sentence
151 | print(f"start: {start}; end: {end}")
152 | return start, end
153 |
154 |
155 | # find batch_pos and feature_dim
156 | def find_batchpos_embdim(tokenizer, model, gpu):
157 | sentence = "今天天气真好"
158 | inputs = tokenizer(sentence, return_tensors="pt")
159 | if gpu != -1:
160 | inputs = inputs.to("cuda")
161 |
162 | with torch.no_grad():
163 | outputs = model(
164 | **inputs, output_hidden_states=True
165 | ).hidden_states # for new version 4.5.1
166 | outputs = torch.stack(outputs)[[-1]].sum(dim=0) # sum => [batch, T, D=768]
167 | outputs = outputs.cpu().numpy() # (B, T, D) or (T, B, D)
168 | batch_pos = None
169 | if outputs.shape[0] == 1:
170 | batch_pos = 0
171 | if outputs.shape[1] == 1:
172 | batch_pos = 1
173 | assert batch_pos in [0, 1]
174 | feature_dim = outputs.shape[2]
175 | print(f"batch_pos:{batch_pos}, feature_dim:{feature_dim}")
176 | return batch_pos, feature_dim
177 |
178 |
179 | # main process
180 | def extract_embedding(
181 | model_name,
182 | trans_dir,
183 | save_dir,
184 | feature_level,
185 | gpu=-1,
186 | punc_case=None,
187 | language="chinese",
188 | model_dir=None,
189 | ):
190 |
191 | print("=" * 30 + f' Extracting "{model_name}" ' + "=" * 30)
192 | start_time = time.time()
193 |
194 | # save last four layers
195 | layer_ids = [-4, -3, -2, -1]
196 |
197 | # save_dir
198 | if punc_case is None and language == "chinese" and model_dir is None:
199 | save_dir = os.path.join(save_dir, f"{model_name}-{feature_level[:3]}")
200 | elif punc_case is not None:
201 | save_dir = os.path.join(
202 | save_dir, f"{model_name}-punc{punc_case}-{feature_level[:3]}"
203 | )
204 | elif language == "english":
205 | save_dir = os.path.join(save_dir, f"{model_name}-langeng-{feature_level[:3]}")
206 | elif model_dir is not None:
207 | prefix_name = "-".join(model_dir.split("/")[-2:])
208 | save_dir = os.path.join(
209 | save_dir, f"{prefix_name}-{model_name}-{feature_level[:3]}"
210 | )
211 | if not os.path.exists(save_dir):
212 | os.makedirs(save_dir)
213 |
214 |     # load model and tokenizer: offline mode (load cached files); the calls look alike, but per-model argument differences keep the branches separate
215 | print("Loading pre-trained tokenizer and model...")
216 | if model_dir is None:
217 | model_dir = os.path.join(
218 | config.PATH_TO_PRETRAINED_MODELS, f"transformers/{model_name}"
219 | )
220 |
221 | if model_name in [DEBERTA_LARGE_CHINESE, ALBERT_TINY_CHINESE, ALBERT_SMALL_CHINESE]:
222 | model = AutoModel.from_pretrained(model_dir)
223 | tokenizer = BertTokenizer.from_pretrained(model_dir, use_fast=False)
224 | elif model_name in [WENZHONG_GPT2_CHINESE]:
225 | model = GPT2Model.from_pretrained(model_dir)
226 | tokenizer = GPT2Tokenizer.from_pretrained(model_dir, use_fast=False)
227 | elif model_name in [
228 | LLAMA_7B,
229 | LLAMA_13B,
230 | LLAMA2_7B,
231 | VICUNA_7B,
232 | VICUNA_13B,
233 | ALPACE_13B,
234 | OPT_13B,
235 | BLOOM_7B,
236 | ]:
237 | model = AutoModel.from_pretrained(model_dir)
238 | tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
239 | elif model_name in [LLAMA2_13B]:
240 | model = AutoModel.from_pretrained(model_dir, use_safetensors=False)
241 | tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
242 | elif model_name in [CHATGLM2_6B, MOSS_7B]:
243 | model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
244 | if model_dir in [
245 | "pretrain-guhao/merged_chatglm2_2",
246 | "pretrain-guhao/merged_chatglm2_3",
247 | ]:
248 | tokenizer = AutoTokenizer.from_pretrained(
249 | "pretrain-guhao/merged_chatglm2", use_fast=False, trust_remote_code=True
250 | )
251 | else:
252 | tokenizer = AutoTokenizer.from_pretrained(
253 | model_dir, use_fast=False, trust_remote_code=True
254 | )
255 | elif model_name in [BAICHUAN_7B, BAICHUAN_13B, BAICHUAN2_7B, BAICHUAN2_13B]:
256 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
257 | tokenizer = AutoTokenizer.from_pretrained(
258 | model_dir, use_fast=False, trust_remote_code=True
259 | )
260 | elif model_name in [STABLEML_7B]:
261 | model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
262 | tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
263 | else:
264 | model = AutoModel.from_pretrained(model_dir)
265 | tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
266 |
267 |     # apply half precision only when a GPU is available and the model is an LLM
268 | if gpu != -1 and model_name in [
269 | LLAMA_7B,
270 | LLAMA_13B,
271 | LLAMA2_7B,
272 | LLAMA2_13B,
273 | VICUNA_7B,
274 | VICUNA_13B,
275 | ALPACE_13B,
276 | OPT_13B,
277 | BLOOM_7B,
278 | CHATGLM2_6B,
279 | MOSS_7B,
280 | BAICHUAN_7B,
281 | FALCON_7B,
282 | BAICHUAN_13B,
283 | STABLEML_7B,
284 | BAICHUAN2_7B,
285 | BAICHUAN2_13B,
286 | ]:
287 | model = model.half()
288 |
289 |     # move the model to CUDA only when a GPU is available
290 | if gpu != -1:
291 | torch.cuda.set_device(gpu)
292 | model.cuda()
293 | model.eval()
294 |
295 | load_time = time.time()
296 |
297 | print("Calculate embeddings...")
298 |     start, end = find_start_end_pos(tokenizer)  # only preserve tokens in [start:end)
299 | batch_pos, feature_dim = find_batchpos_embdim(
300 | tokenizer, model, gpu
301 | ) # find batch pos
302 |
303 | df = pd.read_csv(trans_dir)
304 | for idx, row in df.iterrows():
305 | name = row["name"]
306 | # --------------------------------------------------
307 | if language == "chinese":
308 | sentence = row["chinese"] # process on Chinese
309 | elif language == "english":
310 | sentence = row["english"]
311 | # --------------------------------------------------
312 | print(f"Processing {name} ({idx}/{len(df)})...")
313 |
314 | # extract embedding from sentences
315 | embeddings = []
316 |         if not pd.isna(sentence) and len(sentence) > 0:
317 | inputs = tokenizer(sentence, return_tensors="pt")
318 | if gpu != -1:
319 | inputs = inputs.to("cuda")
320 | with torch.no_grad():
321 | outputs = model(
322 | **inputs, output_hidden_states=True
323 | ).hidden_states # for new version 4.5.1
324 | outputs = torch.stack(outputs)[layer_ids].sum(
325 | dim=0
326 | ) # sum => [batch, T, D=768]
327 | outputs = outputs.cpu().numpy() # (B, T, D)
328 | if batch_pos == 0:
329 | embeddings = outputs[0, start:end]
330 | elif batch_pos == 1:
331 | embeddings = outputs[start:end, 0]
332 |
333 |         # save the embedding of this sample into an npy file
334 | print(f"feature dimension: {feature_dim}")
335 |         save_file = os.path.join(save_dir, f"{name}.npy")
336 | if feature_level == "FRAME":
337 | embeddings = np.array(embeddings).squeeze()
338 | if len(embeddings) == 0:
339 | embeddings = np.zeros((1, feature_dim))
340 | elif len(embeddings.shape) == 1:
341 | embeddings = embeddings[np.newaxis, :]
342 |             np.save(save_file, embeddings)
343 | else:
344 | embeddings = np.array(embeddings).squeeze()
345 | if len(embeddings) == 0:
346 | embeddings = np.zeros((feature_dim,))
347 | elif len(embeddings.shape) == 2:
348 | embeddings = np.mean(embeddings, axis=0)
349 |             np.save(save_file, embeddings)
350 |
351 | end_time = time.time()
352 | print(f"Model load time used: {load_time - start_time:.1f}s.")
353 | print(
354 | f"Total {len(df)} files done! Time used ({model_name}): {end_time - start_time:.1f}s."
355 | )
356 |
357 |
358 | if __name__ == "__main__":
359 |
360 | parser = argparse.ArgumentParser(description="Run.")
361 | parser.add_argument(
362 | "--dataset",
363 | type=str,
364 | default="/sda/xyy/mer/MERTools/MER2023-Dataset-Extended/mer2023-dataset-process/",
365 | help="input dataset",
366 | )
367 | parser.add_argument("--gpu", type=int, default="0", help="gpu id")
368 | parser.add_argument(
369 | "--model_name", type=str, default="bloom-7b1", help="name of pretrained model"
370 | )
371 | parser.add_argument(
372 | "--feature_level",
373 | type=str,
374 | default="UTTERANCE",
375 | choices=["UTTERANCE", "FRAME"],
376 | help="output types",
377 | )
378 |     # ------ temporary: test the impact of punctuation on the results ------
379 | parser.add_argument(
380 | "--punc_case",
381 | type=str,
382 | default=None,
383 | help="test punc impact to the performance",
384 | )
385 |     # ------ temporary: test the impact of language on the results ------
386 |     parser.add_argument("--language", type=str, default="chinese", help="used language")
387 |     # ------ temporary: test an externally provided model_dir [for gu hao] ------
388 | parser.add_argument(
389 | "--model_dir", type=str, default=None, help="used user-defined model_dir"
390 | )
391 | args = parser.parse_args()
392 |
393 | # (trans_dir, save_dir)
394 | if args.punc_case is None:
395 | trans_dir = os.path.join(args.dataset, "text", "transcription.csv")
396 |
397 | save_dir = os.path.join(args.dataset, "features")
398 |
399 | extract_embedding(
400 | model_name=args.model_name,
401 | trans_dir=trans_dir,
402 | save_dir=save_dir,
403 | feature_level=args.feature_level,
404 | gpu=args.gpu,
405 | punc_case=args.punc_case,
406 | language=args.language,
407 | model_dir=args.model_dir,
408 | )
409 |
--------------------------------------------------------------------------------
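extract_text_huggingface.py pools token embeddings in two steps: it sums the hidden states of the last four layers, drops the special tokens located by find_start_end_pos, and for the UTTERANCE level averages the remaining token features. A minimal sketch with random stand-in tensors (not a file in this repository; the 13 layers and 8 tokens are illustrative):

import torch

hidden_states = tuple(torch.randn(1, 8, 768) for _ in range(13))  # embeddings + 12 layers
layer_ids = [-4, -3, -2, -1]                                      # same layer choice as above

summed = torch.stack(hidden_states)[layer_ids].sum(dim=0)  # [1, 8, 768]
start, end = 1, -1                                         # e.g. drop [CLS] and [SEP]
token_feats = summed[0, start:end].numpy()                 # [6, 768]  FRAME-level feature
utt_feat = token_feats.mean(axis=0)                        # [768]     UTTERANCE-level feature
print(token_feats.shape, utt_feat.shape)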