├── Baseline_Model.pdf
├── README.md
├── config.json
├── eval_aro_dim_ser.py
├── eval_cat_ser_weighted.py
├── eval_dim_ser.py
├── eval_dim_ser_test3.py
├── eval_dom_dim_ser.py
├── eval_val_dim_ser.py
├── model
│   └── download_models.sh
├── net
│   ├── __init__.py
│   ├── pooling.py
│   └── ser.py
├── process_labels_for_categorical.py
├── requirements.txt
├── run_arousal.sh
├── run_cat.sh
├── run_dim.sh
├── run_dominance.sh
├── run_valence.sh
├── spec-file.txt
├── train_ft_aro_dim_ser.py
├── train_ft_cat_ser_weighted.py
├── train_ft_dim_ser.py
├── train_ft_dom_dim_ser.py
├── train_ft_val_dim_ser.py
└── utils
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── podcast.py
    │   └── wav.py
    ├── dataset
    │   ├── __init__.py
    │   ├── collate_fn.py
    │   ├── dataset.py
    │   └── normalizer.py
    ├── etc.py
    └── loss_manager.py
/Baseline_Model.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msplabresearch/MSP-Podcast_Challenge/3f60a423b1d3d16c907f1b52e195cbbe32d21976/Baseline_Model.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MSP-Podcast Emotion Challenge - Baselines Training and Evaluation

This repository contains scripts to train and evaluate baseline models for several tasks: categorical emotion recognition, multi-task emotional attribute prediction, and single-task emotional attribute prediction for arousal, dominance, and valence.

Link to the paper: [PDF](https://ecs.utdallas.edu/research/researchlabs/msp-lab/publications/Goncalves_2024.pdf)

Link to the baseline experiments: [Baseline_Model.pdf](Baseline_Model.pdf)

Link to the challenge: [Odyssey 2024 - Emotion Recognition Challenge](https://www.odyssey2024.org/emotion-recognition-challenge)

Link to the [submission website](https://lab-msp.com/MSP-Podcast_Competition/leaderboard.php)

Refer to the links above to sign up for the challenge and to find the rules, submission deadlines, and file formatting instructions.

## Environment Setup

Python version = 3.9.7

To replicate the environment necessary to run the code, you have two options:

### Using Conda

1. Ensure that you have [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/) installed.
2. Create a conda environment from `spec-file.txt` by running:
   conda create --name baseline_env --file spec-file.txt
3. Activate the environment:
   conda activate baseline_env
4. Make sure to install the transformers library, as it is essential for the code to run:
   pip install transformers


### Using pip

1. Alternatively, you can use `requirements.txt` to install the necessary packages in a virtual environment:
   python -m venv myenv
   source myenv/bin/activate
   pip install -r requirements.txt
2. Make sure to install the transformers library, as it is essential for the code to run:
   pip install transformers


## Configuration

Before running the training or evaluation scripts, check the instructions below and update the `config.json` file with the paths to your local audio folder and label CSV file.

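For reference, the `config.json` included in this repository has the following structure; replace the placeholder paths with your local dataset locations:

```
{
    "wav_dir": "/path/to/MSP-PODCAST-1.11/Audios",
    "label_path": "/path/to/processed_labels.csv"
}
```
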
### Categorical Emotion Recognition Model

Before running training or evaluation of the categorical emotion recognition model, execute the script `process_labels_for_categorical.py` to format the provided `labels_consensus.csv` file for categorical emotion recognition. Then place the path of the processed `.csv` file in the `config.json` file to run this configuration.

### Attributes Emotion Recognition Model

The original `labels_consensus.csv` file provided with the dataset can be used as-is for attributes emotion recognition. Place the path to the `labels_consensus.csv` file in the `config.json` file to run this configuration.

## Inference

### HuggingFace

If you are only interested in using the pretrained models for prediction or feature extraction, the models are available on HuggingFace.

#### Models on HuggingFace
- [x] [Categorical model](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Categorical)
- [x] [Multi-task attribute model](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes) - [Emotional Attributes Prediction App/Demo](https://huggingface.co/spaces/3loi/WavLM-SER-Multi-Baseline-Odyssey2024)
- [x] [Valence model](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Valence)
- [x] [Dominance model](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Dominance)
- [x] [Arousal model](https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Arousal)

## Training and Evaluation

To train or evaluate the models, use the provided shell scripts. Here's how to use each script:

- `bash run_cat.sh`: Trains or evaluates the categorical emotion recognition baseline.
- `bash run_dim.sh`: Trains or evaluates the multi-task emotional attributes baseline.
- `bash run_arousal.sh`: Trains or evaluates the single-task emotional attribute baseline for arousal.
- `bash run_dominance.sh`: Trains or evaluates the single-task emotional attribute baseline for dominance.
- `bash run_valence.sh`: Trains or evaluates the single-task emotional attribute baseline for valence.


### Models

The models are saved in the `model` folder. If you are evaluating the pretrained models, download them using the script provided in the `model` folder:
```
$ bash download_models.sh
```
- Example 1: `bash download_models.sh all` to download all the models.
- Example 2: `bash download_models.sh arousal valence` to download the arousal and valence models.

If you wish to download a pre-trained model manually, visit this [website](https://lab-msp.com/MODELS/Odyssey_Baselines/), download the desired model, and place it in the `model` folder.

Pre-trained model file descriptions:
- "weight_cat_ser.zip" --> Categorical emotion recognition baseline.
- "dim_aro_ser.zip" --> Arousal single-task emotional attribute baseline.
- "dim_dom_ser.zip" --> Dominance single-task emotional attribute baseline.
- "dim_val_ser.zip" --> Valence single-task emotional attribute baseline.
- "dim_ser.zip" --> Multi-task emotional attributes baseline.


### Evaluation Only

If you are only evaluating a model and do not wish to train it, comment out the lines related to the `train_*.py` scripts in the respective `.sh` file.

### Evaluation and saving results for emotional attributes prediction

A sample executable for evaluating and saving results is provided in `eval_dim_ser_test3.py`. To run it, download or train the multi-task emotional attributes baseline, then execute `bash run_dim.sh` (note: if you are not retraining the model, comment out the training lines in `run_dim.sh` before evaluating). The results are saved in the required `.csv` format in a `results` folder created inside your model path.
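The saved file uses the header `FileName,EmoAct,EmoVal,EmoDom` required by the challenge, with each predicted attribute mapped to the 1-7 scale (see `eval_dim_ser_test3.py`). A sketch of the output with made-up file names and values:

```
FileName,EmoAct,EmoVal,EmoDom
MSP-PODCAST_test3_0001.wav,4.2513,3.1087,4.0034
MSP-PODCAST_test3_0002.wav,2.7741,5.0312,3.5120
```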
## Issues

If you encounter any issues while setting up, training, or evaluating the models, please open an issue on this repository with a detailed description of the problem.

---------------------------
To cite this repository in your works, use the following BibTeX entry:

```
@InProceedings{Goncalves_2024,
  author={L. Goncalves and A. N. Salman and A. {Reddy Naini} and L. Moro-Velazquez and T. Thebaud and L. {Paola Garcia} and N. Dehak and B. Sisman and C. Busso},
  title={Odyssey2024 - Speech Emotion Recognition Challenge: Dataset, Baseline Framework, and Results},
  booktitle={Odyssey 2024: The Speaker and Language Recognition Workshop},
  volume={To appear},
  year={2024},
  month={June},
  address = {Quebec, Canada},
}
```

--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "wav_dir": "/path/to/MSP-PODCAST-1.11/Audios",
3 |     "label_path": "/path/to/processed_labels.csv"
4 | }
--------------------------------------------------------------------------------
/eval_aro_dim_ser.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Local modules
3 | import os
4 | import sys
5 | import argparse
6 | # 3rd-Party Modules
7 | import numpy as np
8 | import pickle as pk
9 | import pandas as pd
10 | from tqdm import tqdm
11 | import glob
12 | import librosa
13 | import copy
14 | from time import perf_counter
15 | import csv
16 | 
17 | # PyTorch Modules
18 | import torch
19 | import torch.nn as nn
20 | import torch.optim as optim
21 | import torch.nn.functional as F
22 | from torch.utils.data import ConcatDataset, DataLoader
23 | from torch.cuda.amp import GradScaler, autocast
24 | import torch.optim as optim
25 | from transformers import AutoModel
26 | 
27 | # Self-Written Modules
28 | sys.path.append(os.getcwd())
29 | import net
30 | import utils
31 | 
32 | 
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument("--ssl_type", type=str, default="wavlm-large")
35 | parser.add_argument("--model_path", type=str, default="./model/wavlm-large")
36 | parser.add_argument("--pooling_type", type=str, default="MeanPooling")
37 | parser.add_argument("--head_dim", type=int, default=1024)
38 | parser.add_argument('--store_path')
39 | args = parser.parse_args()
40 | 
41 | SSL_TYPE = utils.get_ssl_type(args.ssl_type)
42 | assert SSL_TYPE is not None, "Invalid SSL type!"
43 | MODEL_PATH = args.model_path
44 | 
45 | import json
46 | from collections import defaultdict
47 | config_path = "config.json"
48 | with open(config_path, "r") as f:
49 |     config = json.load(f)
50 | audio_path = config["wav_dir"]
51 | label_path = config["label_path"]
52 | 
53 | total_dataset=dict()
54 | total_dataloader=dict()
55 | for dtype in ["dev"]:
56 |     cur_utts, cur_labs = utils.load_adv_arousal(label_path, dtype)
57 |     cur_wavs = utils.load_audio(audio_path, cur_utts)
58 |     wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl")
59 |     cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std)
60 |     cur_emo_set = utils.ADV_EmoSet(cur_labs)
61 |     total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts])
62 |     total_dataloader[dtype] = DataLoader(
63 |         total_dataset[dtype], batch_size=1, shuffle=False,
64 |         pin_memory=True, num_workers=4,
65 |         collate_fn=utils.collate_fn_wav_lab_mask
66 |     )
67 | 
68 | print("Loading pre-trained ", SSL_TYPE, " model...")
69 | 
70 | ssl_model = AutoModel.from_pretrained(SSL_TYPE)
71 | ssl_model.freeze_feature_encoder()
72 | ssl_model.load_state_dict(torch.load(MODEL_PATH+"/best_ssl.pt"))
73 | ssl_model.eval(); ssl_model.cuda()
74 | ########## Implement pooling method ##########
75 | feat_dim = ssl_model.config.hidden_size
76 | 
77 | pool_net = getattr(net, args.pooling_type)
78 | attention_pool_type_list = ["AttentiveStatisticsPooling"]
79 | if args.pooling_type in attention_pool_type_list:
80 |     is_attentive_pooling = True
81 |     pool_model = pool_net(feat_dim)
82 |     pool_model.load_state_dict(torch.load(MODEL_PATH+"/best_pool.pt"))
83 | else:
84 |     is_attentive_pooling = False
85 |     pool_model = pool_net()
86 | print(pool_model)
87 | 
88 | pool_model.eval();
89 | pool_model.cuda()
90 | concat_pool_type_list = ["AttentiveStatisticsPooling"]
91 | dh_input_dim = feat_dim * 2 \
92 |     if args.pooling_type in concat_pool_type_list \
93 |     else feat_dim
94 | 
95 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 1, dropout=0.5)
96 | ##############################################
97 | ser_model.load_state_dict(torch.load(MODEL_PATH+"/best_ser.pt"))
98 | ser_model.eval(); ser_model.cuda()
99 | 
100 | 
101 | lm = utils.LogManager()
102 | for dtype in ["dev"]:
103 |     lm.alloc_stat_type_list([f"{dtype}_aro"])
104 | 
105 | min_epoch=0
106 | min_loss=1e10
107 | 
108 | lm.init_stat()
109 | 
110 | ssl_model.eval()
111 | ser_model.eval()
112 | 
113 | 
114 | INFERENCE_TIME=0
115 | FRAME_SEC = 0
116 | for dtype in ["dev"]:
117 |     total_pred = []
118 |     total_y = []
119 |     total_utt = []
120 |     for xy_pair in tqdm(total_dataloader[dtype]):
121 |         x = xy_pair[0]; x=x.cuda(non_blocking=True).float()
122 |         y = xy_pair[1]; y=y.cuda(non_blocking=True).float()
123 |         mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float()
124 |         fname = xy_pair[3]
125 | 
126 |         FRAME_SEC += (mask.sum()/16000)
127 |         stime = perf_counter()
128 |         with torch.no_grad():
129 |             ssl = ssl_model(x, attention_mask=mask).last_hidden_state
130 |             ssl = pool_model(ssl, mask)
131 |             emo_pred = ser_model(ssl)
132 | 
133 |         total_pred.append(emo_pred)
134 |         total_y.append(y)
135 |         total_utt.append(fname)
136 |         etime = perf_counter()
137 |         INFERENCE_TIME += (etime-stime)
138 |     # CCC calculation
139 |     total_pred = torch.cat(total_pred, 0)
140 |     total_y = torch.cat(total_y, 0)
141 |     ccc = utils.CCC_loss(total_pred, total_y)
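    # Note: utils.CCC_loss is not included in this listing. It returns the
    # concordance correlation coefficient (CCC) per predicted dimension; CCC
    # between predictions x and labels y is commonly computed as
    #     2 * cov(x, y) / (var(x) + var(y) + (mean(x) - mean(y)) ** 2)
    # so values closer to 1 indicate better agreement with the labels.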
142 |     # Logging
143 |     lm.add_torch_stat(f"{dtype}_aro", ccc[0])
144 | 
145 | 
146 |     data = []
147 |     for y, pred, utt in zip(total_y, total_pred, total_utt):
148 |         pred_values = ', '.join([f'{val:.4f}' for val in pred.cpu().numpy().flatten()])
149 |         data.append([utt[0], pred_values])
150 |     os.makedirs(MODEL_PATH + '/results', exist_ok=True)  # make sure the results folder exists
151 |     # Writing to CSV file
152 |     csv_filename = MODEL_PATH + '/results/' + dtype + '.csv'
153 |     with open(csv_filename, mode='w', newline='') as file:
154 |         writer = csv.writer(file)
155 |         writer.writerow(['FileName', 'Prediction'])
156 |         writer.writerows(data)
157 | 
158 | lm.print_stat()
159 | print("Duration of whole dev set", FRAME_SEC, "sec")
160 | 
print("Inference time", INFERENCE_TIME, "sec") 161 | print("Inference time per sec", INFERENCE_TIME/FRAME_SEC, "sec") 162 | 163 | os.makedirs(os.path.dirname(args.store_path), exist_ok=True) 164 | with open(args.store_path, 'w') as f: 165 | for dtype in ["dev"]: 166 | aro = str(lm.get_stat(f"{dtype}_aro")) 167 | f.write(aro+"\n") 168 | -------------------------------------------------------------------------------- /eval_cat_ser_weighted.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | import numpy as np 8 | import pickle as pk 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import glob 12 | import librosa 13 | import copy 14 | import csv 15 | from time import perf_counter 16 | from sklearn.metrics import precision_score, recall_score, f1_score 17 | from sklearn.preprocessing import MultiLabelBinarizer 18 | 19 | 20 | # PyTorch Modules 21 | import torch 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | import torch.nn.functional as F 25 | from torch.utils.data import ConcatDataset, DataLoader 26 | import torch.optim as optim 27 | from transformers import AutoModel 28 | 29 | # Self-Written Modules 30 | sys.path.append(os.getcwd()) 31 | import net 32 | import utils 33 | 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 37 | parser.add_argument("--model_path", type=str, default="./model/wavlm-large") 38 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 39 | parser.add_argument("--head_dim", type=int, default=1024) 40 | parser.add_argument('--store_path') 41 | args = parser.parse_args() 42 | 43 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 44 | assert SSL_TYPE != None, print("Invalid SSL type!") 45 | MODEL_PATH = args.model_path 46 | 47 | import json 48 | from collections import defaultdict 49 | config_path = "config_cat.json" 50 | with open(config_path, "r") as f: 51 | config = json.load(f) 52 | audio_path = config["wav_dir"] 53 | label_path = config["label_path"] 54 | 55 | import pandas as pd 56 | import numpy as np 57 | 58 | # Load the CSV file 59 | df = pd.read_csv(label_path) 60 | 61 | # Filter out only 'Train' samples 62 | train_df = df[df['Split_Set'] == 'Train'] 63 | 64 | # Classes (emotions) 65 | classes = ['Angry', 'Sad', 'Happy', 'Surprise', 'Fear', 'Disgust', 'Contempt', 'Neutral'] 66 | 67 | # Calculate class frequencies 68 | class_frequencies = train_df[classes].sum().to_dict() 69 | 70 | # Total number of samples 71 | total_samples = len(train_df) 72 | 73 | # Calculate class weights 74 | class_weights = {cls: total_samples / (len(classes) * freq) if freq != 0 else 0 for cls, freq in class_frequencies.items()} 75 | 76 | print(class_weights) 77 | 78 | # Convert to list in the order of classes 79 | weights_list = [class_weights[cls] for cls in classes] 80 | 81 | # Convert to PyTorch tensor 82 | class_weights_tensor = torch.tensor(weights_list, device='cuda', dtype=torch.float) 83 | 84 | 85 | # Print or return the tensor 86 | print(class_weights_tensor) 87 | 88 | total_dataset=dict() 89 | total_dataloader=dict() 90 | for dtype in ["dev"]: 91 | cur_utts, cur_labs = utils.load_cat_emo_label(label_path, dtype) 92 | cur_wavs = utils.load_audio(audio_path, cur_utts) 93 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 94 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, 
wav_std=wav_std) 95 | cur_emo_set = utils.CAT_EmoSet(cur_labs) 96 | 97 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 98 | total_dataloader[dtype] = DataLoader( 99 | total_dataset[dtype], batch_size=1, shuffle=False, 100 | pin_memory=True, num_workers=4, 101 | collate_fn=utils.collate_fn_wav_lab_mask 102 | ) 103 | 104 | print("Loading pre-trained ", SSL_TYPE, " model...") 105 | 106 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 107 | ssl_model.freeze_feature_encoder() 108 | ssl_model.load_state_dict(torch.load(MODEL_PATH+"/best_ssl.pt")) 109 | ssl_model.eval(); ssl_model.cuda() 110 | ########## Implement pooling method ########## 111 | feat_dim = ssl_model.config.hidden_size 112 | 113 | pool_net = getattr(net, args.pooling_type) 114 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 115 | if args.pooling_type in attention_pool_type_list: 116 | is_attentive_pooling = True 117 | pool_model = pool_net(feat_dim) 118 | pool_model.load_state_dict(torch.load(MODEL_PATH+"/best_pool.pt")) 119 | else: 120 | is_attentive_pooling = False 121 | pool_model = pool_net() 122 | print(pool_model) 123 | 124 | pool_model.eval() 125 | pool_model.cuda() 126 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 127 | dh_input_dim = feat_dim * 2 \ 128 | if args.pooling_type in concat_pool_type_list \ 129 | else feat_dim 130 | 131 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 8, dropout=0.5) 132 | ############################################## 133 | ser_model.load_state_dict(torch.load(MODEL_PATH+"/best_ser.pt")) 134 | ser_model.eval(); ser_model.cuda() 135 | 136 | 137 | lm = utils.LogManager() 138 | for dtype in ["dev"]: 139 | lm.alloc_stat_type_list([f"{dtype}_loss"]) 140 | 141 | min_epoch=0 142 | min_loss=1e10 143 | 144 | lm.init_stat() 145 | 146 | ssl_model.eval() 147 | ser_model.eval() 148 | 149 | if not os.path.exists(MODEL_PATH + '/results'): 150 | os.mkdir(MODEL_PATH + '/results') 151 | INFERENCE_TIME=0 152 | FRAME_SEC = 0 153 | for dtype in ["dev"]: 154 | total_pred = [] 155 | total_y = [] 156 | total_utt = [] 157 | for xy_pair in tqdm(total_dataloader[dtype]): 158 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 159 | y = xy_pair[1]; y=y.max(dim=1)[1]; y=y.cuda(non_blocking=True).long() 160 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 161 | fname = xy_pair[3] 162 | 163 | FRAME_SEC += (mask.sum()/16000) 164 | stime = perf_counter() 165 | with torch.no_grad(): 166 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 167 | ssl = pool_model(ssl, mask) 168 | emo_pred = ser_model(ssl) 169 | 170 | total_pred.append(emo_pred) 171 | total_y.append(y) 172 | total_utt.append(fname) 173 | 174 | etime = perf_counter() 175 | INFERENCE_TIME += (etime-stime) 176 | 177 | def label_to_one_hot(label, num_classes=8): 178 | one_hot = ['0.0'] * num_classes 179 | one_hot[label.item()] = '1.0' 180 | return ','.join(one_hot) 181 | 182 | data = [] 183 | for y, pred, utt in zip(total_y, total_pred, total_utt): 184 | one_hot_label = label_to_one_hot(y.cpu()) 185 | pred_values = ', '.join([f'{val:.4f}' for val in pred.cpu().numpy().flatten()]) 186 | data.append([utt[0], one_hot_label, pred_values]) 187 | 188 | # Writing to CSV file 189 | csv_filename = MODEL_PATH + '/results/' + dtype + '.csv' 190 | with open(csv_filename, mode='w', newline='') as file: 191 | writer = csv.writer(file) 192 | writer.writerow(['Filename', 'Label', 'Prediction']) 193 | writer.writerows(data) 194 | 195 | 196 | ################################## 197 | 198 
|     # Load the CSV file
199 |     df = pd.read_csv(csv_filename)
200 | 
201 |     # Function to convert string representation of one-hot vectors to numpy arrays
202 |     def string_to_array(s):
203 |         return np.array([float(i) for i in s.strip('\"').split(',')])
204 | 
205 |     # Convert the string representations to numpy arrays
206 |     df['Label'] = df['Label'].apply(string_to_array)
207 |     df['Prediction'] = df['Prediction'].apply(string_to_array)
208 | 
209 |     # Use argmax to determine the class with the highest probability
210 |     y_true = np.argmax(np.stack(df['Label'].values), axis=1)
211 |     y_pred = np.argmax(np.stack(df['Prediction'].values), axis=1)
212 | 
213 |     # Compute metrics
214 |     f1_micro = f1_score(y_true, y_pred, average='micro')
215 |     f1_macro = f1_score(y_true, y_pred, average='macro')
216 |     precision = precision_score(y_true, y_pred, average='macro')
217 |     recall = recall_score(y_true, y_pred, average='macro')
218 | 
219 |     # Print results
220 |     print(f"F1-Micro: {f1_micro}")
221 |     print(f"F1-Macro: {f1_macro}")
222 |     print(f"Precision: {precision}")
223 |     print(f"Recall: {recall}")
224 | 
225 |     # Save the results in a text file
226 |     with open(MODEL_PATH + '/results/' + dtype + '.txt', 'w') as f:
227 |         f.write(f"F1-Micro: {f1_micro}\n")
228 |         f.write(f"F1-Macro: {f1_macro}\n")
229 |         f.write(f"Precision: {precision}\n")
230 |         f.write(f"Recall: {recall}\n")
231 | 
232 | 
233 | 
234 | 
235 | 
236 | 
237 |     # Weighted cross-entropy loss calculation
238 |     total_pred = torch.cat(total_pred, 0)
239 |     total_y = torch.cat(total_y, 0)
240 |     loss = utils.CE_weight_category(total_pred, total_y, class_weights_tensor)
241 |     # Logging
242 |     lm.add_torch_stat(f"{dtype}_loss", loss)
243 | 
244 | 
245 | lm.print_stat()
246 | print("Duration of whole dev set", FRAME_SEC, "sec")
247 | print("Inference time", INFERENCE_TIME, "sec")
248 | print("Inference time per sec", INFERENCE_TIME/FRAME_SEC, "sec")
249 | 
250 | os.makedirs(os.path.dirname(args.store_path), exist_ok=True)
251 | with open(args.store_path, 'w') as f:
252 |     for dtype in ["dev"]:
253 |         loss = str(lm.get_stat(f"{dtype}_loss"))
254 |         f.write(loss+"\n")
255 | 
--------------------------------------------------------------------------------
/eval_dim_ser.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Local modules
3 | import os
4 | import sys
5 | import argparse
6 | # 3rd-Party Modules
7 | import numpy as np
8 | import pickle as pk
9 | import pandas as pd
10 | from tqdm import tqdm
11 | import glob
12 | import librosa
13 | import copy
14 | from time import perf_counter
15 | 
16 | # PyTorch Modules
17 | import torch
18 | import torch.nn as nn
19 | import torch.optim as optim
20 | import torch.nn.functional as F
21 | from torch.utils.data import ConcatDataset, DataLoader
22 | from torch.cuda.amp import GradScaler, autocast
23 | import torch.optim as optim
24 | from transformers import AutoModel
25 | 
26 | # Self-Written Modules
27 | sys.path.append(os.getcwd())
28 | import net
29 | import utils
30 | 
31 | 
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--ssl_type", type=str, default="wavlm-large")
34 | parser.add_argument("--model_path", type=str, default="./model/wavlm-large")
35 | parser.add_argument("--pooling_type", type=str, default="MeanPooling")
36 | parser.add_argument("--head_dim", type=int, default=1024)
37 | parser.add_argument('--store_path')
38 | args = parser.parse_args()
39 | 
40 | SSL_TYPE = utils.get_ssl_type(args.ssl_type)
41 | assert SSL_TYPE is not None, "Invalid SSL type!"
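# Note: this multi-task evaluator mirrors the single-attribute scripts above;
# the difference is that its regression head outputs three values per utterance
# (arousal, dominance, and valence), as set up by the
# net.EmotionRegression(dh_input_dim, args.head_dim, 1, 3, ...) call further below.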
42 | MODEL_PATH = args.model_path 43 | 44 | import json 45 | from collections import defaultdict 46 | config_path = "config.json" 47 | with open(config_path, "r") as f: 48 | config = json.load(f) 49 | audio_path = config["wav_dir"] 50 | label_path = config["label_path"] 51 | 52 | total_dataset=dict() 53 | total_dataloader=dict() 54 | for dtype in ["dev"]: 55 | cur_utts, cur_labs = utils.load_adv_emo_label(label_path, dtype) 56 | cur_wavs = utils.load_audio(audio_path, cur_utts) 57 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 58 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 59 | cur_emo_set = utils.ADV_EmoSet(cur_labs) 60 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 61 | total_dataloader[dtype] = DataLoader( 62 | total_dataset[dtype], batch_size=1, shuffle=False, 63 | pin_memory=True, num_workers=4, 64 | collate_fn=utils.collate_fn_wav_lab_mask 65 | ) 66 | 67 | print("Loading pre-trained ", SSL_TYPE, " model...") 68 | 69 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 70 | ssl_model.freeze_feature_encoder() 71 | ssl_model.load_state_dict(torch.load(MODEL_PATH+"/final_ssl.pt")) 72 | ssl_model.eval(); ssl_model.cuda() 73 | ########## Implement pooling method ########## 74 | feat_dim = ssl_model.config.hidden_size 75 | 76 | pool_net = getattr(net, args.pooling_type) 77 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 78 | if args.pooling_type in attention_pool_type_list: 79 | is_attentive_pooling = True 80 | pool_model = pool_net(feat_dim) 81 | pool_model.load_state_dict(torch.load(MODEL_PATH+"/final_pool.pt")) 82 | else: 83 | is_attentive_pooling = False 84 | pool_model = pool_net() 85 | print(pool_model) 86 | 87 | pool_model.eval(); 88 | pool_model.cuda() 89 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 90 | dh_input_dim = feat_dim * 2 \ 91 | if args.pooling_type in concat_pool_type_list \ 92 | else feat_dim 93 | 94 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 3, dropout=0.5) 95 | ############################################## 96 | ser_model.load_state_dict(torch.load(MODEL_PATH+"/final_ser.pt")) 97 | ser_model.eval(); ser_model.cuda() 98 | 99 | 100 | lm = utils.LogManager() 101 | for dtype in ["dev"]: 102 | lm.alloc_stat_type_list([f"{dtype}_aro", f"{dtype}_dom", f"{dtype}_val"]) 103 | 104 | min_epoch=0 105 | min_loss=1e10 106 | 107 | lm.init_stat() 108 | 109 | ssl_model.eval() 110 | ser_model.eval() 111 | 112 | 113 | INFERENCE_TIME=0 114 | FRAME_SEC = 0 115 | for dtype in ["dev"]: 116 | total_pred = [] 117 | total_y = [] 118 | for xy_pair in tqdm(total_dataloader[dtype]): 119 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 120 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 121 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 122 | 123 | FRAME_SEC += (mask.sum()/16000) 124 | stime = perf_counter() 125 | with torch.no_grad(): 126 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 127 | ssl = pool_model(ssl, mask) 128 | emo_pred = ser_model(ssl) 129 | 130 | total_pred.append(emo_pred) 131 | total_y.append(y) 132 | etime = perf_counter() 133 | INFERENCE_TIME += (etime-stime) 134 | # CCC calculation 135 | total_pred = torch.cat(total_pred, 0) 136 | total_y = torch.cat(total_y, 0) 137 | ccc = utils.CCC_loss(total_pred, total_y) 138 | # Logging 139 | lm.add_torch_stat(f"{dtype}_aro", ccc[0]) 140 | lm.add_torch_stat(f"{dtype}_dom", ccc[1]) 141 | lm.add_torch_stat(f"{dtype}_val", ccc[2]) 142 | 143 | lm.print_stat() 144 | 
print("Duration of whole dev+test set", FRAME_SEC, "sec") 145 | print("Inference time", INFERENCE_TIME, "sec") 146 | print("Inference time per sec", INFERENCE_TIME/FRAME_SEC, "sec") 147 | 148 | os.makedirs(os.path.dirname(args.store_path), exist_ok=True) 149 | with open(args.store_path, 'w') as f: 150 | for dtype in ["dev"]: 151 | aro = str(lm.get_stat(f"{dtype}_aro")) 152 | dom = str(lm.get_stat(f"{dtype}_dom")) 153 | val = str(lm.get_stat(f"{dtype}_val")) 154 | f.write(aro+","+dom+","+val+"\n") 155 | -------------------------------------------------------------------------------- /eval_dim_ser_test3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | from tqdm import tqdm 8 | from time import perf_counter 9 | 10 | # PyTorch Modules 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from transformers import AutoModel 14 | import csv 15 | 16 | # Self-Written Modules 17 | sys.path.append(os.getcwd()) 18 | import net 19 | import utils 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 24 | parser.add_argument("--model_path", type=str, default="./model/wavlm-large") 25 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 26 | parser.add_argument("--head_dim", type=int, default=1024) 27 | parser.add_argument('--store_path') 28 | args = parser.parse_args() 29 | 30 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 31 | assert SSL_TYPE != None, print("Invalid SSL type!") 32 | MODEL_PATH = args.model_path 33 | 34 | import json 35 | from collections import defaultdict 36 | config_path = "config.json" 37 | with open(config_path, "r") as f: 38 | config = json.load(f) 39 | audio_path = config["wav_dir"] 40 | label_path = config["label_path"] 41 | 42 | files_test3 = [filename for filename in os.listdir(audio_path) if 'test3' in filename] 43 | 44 | dtype = "test3" 45 | 46 | total_dataset=dict() 47 | total_dataloader=dict() 48 | 49 | cur_wavs = utils.load_audio(audio_path, files_test3) 50 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 51 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 52 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, files_test3]) 53 | total_dataloader[dtype] = DataLoader( 54 | total_dataset[dtype], batch_size=1, shuffle=False, 55 | pin_memory=True, num_workers=4, 56 | collate_fn=utils.collate_fn_wav_test3 57 | ) 58 | 59 | print("Loading pre-trained ", SSL_TYPE, " model...") 60 | 61 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 62 | ssl_model.freeze_feature_encoder() 63 | ssl_model.load_state_dict(torch.load(MODEL_PATH+"/final_ssl.pt")) 64 | ssl_model.eval(); ssl_model.cuda() 65 | ########## Implement pooling method ########## 66 | feat_dim = ssl_model.config.hidden_size 67 | 68 | pool_net = getattr(net, args.pooling_type) 69 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 70 | if args.pooling_type in attention_pool_type_list: 71 | is_attentive_pooling = True 72 | pool_model = pool_net(feat_dim) 73 | pool_model.load_state_dict(torch.load(MODEL_PATH+"/final_pool.pt")) 74 | else: 75 | is_attentive_pooling = False 76 | pool_model = pool_net() 77 | print(pool_model) 78 | 79 | pool_model.eval(); 80 | pool_model.cuda() 81 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 82 | dh_input_dim = feat_dim * 2 \ 83 | if args.pooling_type in 
concat_pool_type_list \ 84 | else feat_dim 85 | 86 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 3, dropout=0.5) 87 | ############################################## 88 | ser_model.load_state_dict(torch.load(MODEL_PATH+"/final_ser.pt")) 89 | ser_model.eval(); ser_model.cuda() 90 | 91 | 92 | lm = utils.LogManager() 93 | for dtype in ["test3"]: 94 | lm.alloc_stat_type_list([f"{dtype}_aro", f"{dtype}_dom", f"{dtype}_val"]) 95 | 96 | min_epoch=0 97 | min_loss=1e10 98 | 99 | lm.init_stat() 100 | 101 | ssl_model.eval() 102 | ser_model.eval() 103 | 104 | 105 | INFERENCE_TIME=0 106 | FRAME_SEC = 0 107 | for dtype in ["test3"]: 108 | total_pred = [] 109 | total_utt = [] 110 | for xy_pair in tqdm(total_dataloader[dtype]): 111 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 112 | mask = xy_pair[1]; mask=mask.cuda(non_blocking=True).float() 113 | utts = xy_pair[2] 114 | 115 | 116 | FRAME_SEC += (mask.sum()/16000) 117 | stime = perf_counter() 118 | with torch.no_grad(): 119 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 120 | ssl = pool_model(ssl, mask) 121 | emo_pred = ser_model(ssl) 122 | 123 | total_pred.append(emo_pred) 124 | total_utt.append(utts) 125 | etime = perf_counter() 126 | INFERENCE_TIME += (etime-stime) 127 | 128 | # CCC calculation 129 | total_pred = torch.cat(total_pred, 0) 130 | 131 | data = [] 132 | for pred, utt in zip(total_pred, total_utt): 133 | print(pred) 134 | pred_values = pred.cpu().tolist() 135 | print(pred_values) 136 | data.append([utt[0], min(max(1, pred_values[0] * 6 + 1), 7),min(max(1, pred_values[2] * 6 + 1), 7),min(max(1, pred_values[1] * 6 + 1), 7)]) 137 | 138 | # print(data) 139 | # Writing to CSV file 140 | os.makedirs(MODEL_PATH + '/results', exist_ok=True) 141 | csv_filename = MODEL_PATH + '/results/' + dtype + '.csv' 142 | with open(csv_filename, mode='w', newline='') as file: 143 | writer = csv.writer(file) 144 | writer.writerow(["FileName", "EmoAct", "EmoVal", "EmoDom"]) 145 | writer.writerows(data) 146 | 147 | 148 | lm.print_stat() 149 | print("Duration of whole dev+test set", FRAME_SEC, "sec") 150 | print("Inference time", INFERENCE_TIME, "sec") 151 | print("Inference time per sec", INFERENCE_TIME/FRAME_SEC, "sec") 152 | 153 | -------------------------------------------------------------------------------- /eval_dom_dim_ser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | import numpy as np 8 | import pickle as pk 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import glob 12 | import librosa 13 | import copy 14 | import csv 15 | from time import perf_counter 16 | 17 | # PyTorch Modules 18 | import torch 19 | import torch.nn as nn 20 | import torch.optim as optim 21 | import torch.nn.functional as F 22 | from torch.utils.data import ConcatDataset, DataLoader 23 | from torch.cuda.amp import GradScaler, autocast 24 | import torch.optim as optim 25 | from transformers import AutoModel 26 | 27 | # Self-Written Modules 28 | sys.path.append(os.getcwd()) 29 | import net 30 | import utils 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 35 | parser.add_argument("--model_path", type=str, default="./model/wavlm-large") 36 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 37 | parser.add_argument("--head_dim", type=int, default=1024) 38 | 
parser.add_argument('--store_path') 39 | args = parser.parse_args() 40 | 41 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 42 | assert SSL_TYPE != None, print("Invalid SSL type!") 43 | MODEL_PATH = args.model_path 44 | 45 | import json 46 | from collections import defaultdict 47 | config_path = "config.json" 48 | with open(config_path, "r") as f: 49 | config = json.load(f) 50 | audio_path = config["wav_dir"] 51 | label_path = config["label_path"] 52 | 53 | total_dataset=dict() 54 | total_dataloader=dict() 55 | for dtype in ["dev"]: 56 | cur_utts, cur_labs = utils.load_adv_dominance(label_path, dtype) 57 | cur_wavs = utils.load_audio(audio_path, cur_utts) 58 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 59 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 60 | cur_emo_set = utils.ADV_EmoSet(cur_labs) 61 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set,cur_utts]) 62 | total_dataloader[dtype] = DataLoader( 63 | total_dataset[dtype], batch_size=1, shuffle=False, 64 | pin_memory=True, num_workers=4, 65 | collate_fn=utils.collate_fn_wav_lab_mask 66 | ) 67 | 68 | print("Loading pre-trained ", SSL_TYPE, " model...") 69 | 70 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 71 | ssl_model.freeze_feature_encoder() 72 | ssl_model.load_state_dict(torch.load(MODEL_PATH+"/final_ssl.pt")) 73 | ssl_model.eval(); ssl_model.cuda() 74 | ########## Implement pooling method ########## 75 | feat_dim = ssl_model.config.hidden_size 76 | 77 | pool_net = getattr(net, args.pooling_type) 78 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 79 | if args.pooling_type in attention_pool_type_list: 80 | is_attentive_pooling = True 81 | pool_model = pool_net(feat_dim) 82 | pool_model.load_state_dict(torch.load(MODEL_PATH+"/final_pool.pt")) 83 | else: 84 | is_attentive_pooling = False 85 | pool_model = pool_net() 86 | print(pool_model) 87 | 88 | pool_model.eval(); 89 | pool_model.cuda() 90 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 91 | dh_input_dim = feat_dim * 2 \ 92 | if args.pooling_type in concat_pool_type_list \ 93 | else feat_dim 94 | 95 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 1, dropout=0.5) 96 | ############################################## 97 | ser_model.load_state_dict(torch.load(MODEL_PATH+"/final_ser.pt")) 98 | ser_model.eval(); ser_model.cuda() 99 | 100 | 101 | lm = utils.LogManager() 102 | for dtype in ["dev"]: 103 | lm.alloc_stat_type_list([f"{dtype}_dom"]) 104 | 105 | min_epoch=0 106 | min_loss=1e10 107 | 108 | lm.init_stat() 109 | 110 | ssl_model.eval() 111 | ser_model.eval() 112 | 113 | 114 | INFERENCE_TIME=0 115 | FRAME_SEC = 0 116 | for dtype in ["dev"]: 117 | total_pred = [] 118 | total_y = [] 119 | total_utt = [] 120 | for xy_pair in tqdm(total_dataloader[dtype]): 121 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 122 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 123 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 124 | fname = xy_pair[3] 125 | 126 | FRAME_SEC += (mask.sum()/16000) 127 | stime = perf_counter() 128 | with torch.no_grad(): 129 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 130 | ssl = pool_model(ssl, mask) 131 | emo_pred = ser_model(ssl) 132 | 133 | total_pred.append(emo_pred) 134 | total_y.append(y) 135 | total_utt.append(fname) 136 | etime = perf_counter() 137 | INFERENCE_TIME += (etime-stime) 138 | # CCC calculation 139 | total_pred = torch.cat(total_pred, 0) 140 | total_y = torch.cat(total_y, 0) 141 | ccc = 
utils.CCC_loss(total_pred, total_y) 142 | # Logging 143 | lm.add_torch_stat(f"{dtype}_dom", ccc[0]) 144 | 145 | data = [] 146 | for y, pred, utt in zip(total_y, total_pred, total_utt): 147 | pred_values = ', '.join([f'{val:.4f}' for val in pred.cpu().numpy().flatten()]) 148 | data.append([utt[0], pred_values]) 149 | 150 | # Writing to CSV file 151 | csv_filename = MODEL_PATH + '/results/' + dtype + '.csv' 152 | with open(csv_filename, mode='w', newline='') as file: 153 | writer = csv.writer(file) 154 | writer.writerow(['FileName', 'Prediction']) 155 | writer.writerows(data) 156 | 157 | lm.print_stat() 158 | print("Duration of whole dev+test set", FRAME_SEC, "sec") 159 | print("Inference time", INFERENCE_TIME, "sec") 160 | print("Inference time per sec", INFERENCE_TIME/FRAME_SEC, "sec") 161 | 162 | os.makedirs(os.path.dirname(args.store_path), exist_ok=True) 163 | with open(args.store_path, 'w') as f: 164 | for dtype in ["dev"]: 165 | dom = str(lm.get_stat(f"{dtype}_dom")) 166 | f.write(dom+"\n") 167 | -------------------------------------------------------------------------------- /eval_val_dim_ser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | import numpy as np 8 | import pickle as pk 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import glob 12 | import librosa 13 | import copy 14 | import csv 15 | from time import perf_counter 16 | 17 | # PyTorch Modules 18 | import torch 19 | import torch.nn as nn 20 | import torch.optim as optim 21 | import torch.nn.functional as F 22 | from torch.utils.data import ConcatDataset, DataLoader 23 | from torch.cuda.amp import GradScaler, autocast 24 | import torch.optim as optim 25 | from transformers import AutoModel 26 | 27 | # Self-Written Modules 28 | sys.path.append(os.getcwd()) 29 | import net 30 | import utils 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 35 | parser.add_argument("--model_path", type=str, default="./model/wavlm-large") 36 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 37 | parser.add_argument("--head_dim", type=int, default=1024) 38 | parser.add_argument('--store_path') 39 | args = parser.parse_args() 40 | 41 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 42 | assert SSL_TYPE != None, print("Invalid SSL type!") 43 | MODEL_PATH = args.model_path 44 | 45 | import json 46 | from collections import defaultdict 47 | config_path = "config.json" 48 | with open(config_path, "r") as f: 49 | config = json.load(f) 50 | audio_path = config["wav_dir"] 51 | label_path = config["label_path"] 52 | 53 | total_dataset=dict() 54 | total_dataloader=dict() 55 | for dtype in ["dev"]: 56 | cur_utts, cur_labs = utils.load_adv_valence(label_path, dtype) 57 | cur_wavs = utils.load_audio(audio_path, cur_utts) 58 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 59 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 60 | cur_emo_set = utils.ADV_EmoSet(cur_labs) 61 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 62 | total_dataloader[dtype] = DataLoader( 63 | total_dataset[dtype], batch_size=1, shuffle=False, 64 | pin_memory=True, num_workers=4, 65 | collate_fn=utils.collate_fn_wav_lab_mask 66 | ) 67 | 68 | print("Loading pre-trained ", SSL_TYPE, " model...") 69 | 70 | ssl_model = 
AutoModel.from_pretrained(SSL_TYPE)# 71 | ssl_model.freeze_feature_encoder() 72 | ssl_model.load_state_dict(torch.load(MODEL_PATH+"/best_ssl.pt")) 73 | ssl_model.eval(); ssl_model.cuda() 74 | ########## Implement pooling method ########## 75 | feat_dim = ssl_model.config.hidden_size 76 | 77 | pool_net = getattr(net, args.pooling_type) 78 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 79 | if args.pooling_type in attention_pool_type_list: 80 | is_attentive_pooling = True 81 | pool_model = pool_net(feat_dim) 82 | pool_model.load_state_dict(torch.load(MODEL_PATH+"/best_pool.pt")) 83 | else: 84 | is_attentive_pooling = False 85 | pool_model = pool_net() 86 | print(pool_model) 87 | 88 | pool_model.eval(); 89 | pool_model.cuda() 90 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 91 | dh_input_dim = feat_dim * 2 \ 92 | if args.pooling_type in concat_pool_type_list \ 93 | else feat_dim 94 | 95 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 1, dropout=0.5) 96 | ############################################## 97 | ser_model.load_state_dict(torch.load(MODEL_PATH+"/best_ser.pt")) 98 | ser_model.eval(); ser_model.cuda() 99 | 100 | 101 | lm = utils.LogManager() 102 | for dtype in ["dev"]: 103 | lm.alloc_stat_type_list([f"{dtype}_val"]) 104 | 105 | min_epoch=0 106 | min_loss=1e10 107 | 108 | lm.init_stat() 109 | 110 | ssl_model.eval() 111 | ser_model.eval() 112 | 113 | 114 | INFERENCE_TIME=0 115 | FRAME_SEC = 0 116 | for dtype in ["dev"]: 117 | total_pred = [] 118 | total_y = [] 119 | total_utt = [] 120 | for xy_pair in tqdm(total_dataloader[dtype]): 121 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 122 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 123 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 124 | fname = xy_pair[3] 125 | 126 | FRAME_SEC += (mask.sum()/16000) 127 | stime = perf_counter() 128 | with torch.no_grad(): 129 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 130 | ssl = pool_model(ssl, mask) 131 | emo_pred = ser_model(ssl) 132 | 133 | total_pred.append(emo_pred) 134 | total_y.append(y) 135 | total_utt.append(fname) 136 | etime = perf_counter() 137 | INFERENCE_TIME += (etime-stime) 138 | # CCC calculation 139 | total_pred = torch.cat(total_pred, 0) 140 | total_y = torch.cat(total_y, 0) 141 | ccc = utils.CCC_loss(total_pred, total_y) 142 | # Logging 143 | lm.add_torch_stat(f"{dtype}_val", ccc[0]) 144 | 145 | data = [] 146 | for y, pred, utt in zip(total_y, total_pred, total_utt): 147 | pred_values = ', '.join([f'{val:.4f}' for val in pred.cpu().numpy().flatten()]) 148 | data.append([utt[0], pred_values]) 149 | 150 | # Writing to CSV file 151 | csv_filename = MODEL_PATH + '/results/' + dtype + '.csv' 152 | with open(csv_filename, mode='w', newline='') as file: 153 | writer = csv.writer(file) 154 | writer.writerow(['FileName', 'Prediction']) 155 | writer.writerows(data) 156 | 157 | lm.print_stat() 158 | print("Duration of whole dev+test set", FRAME_SEC, "sec") 159 | print("Inference time", INFERENCE_TIME, "sec") 160 | print("Inference time per sec", INFERENCE_TIME/FRAME_SEC, "sec") 161 | 162 | os.makedirs(os.path.dirname(args.store_path), exist_ok=True) 163 | with open(args.store_path, 'w') as f: 164 | for dtype in ["dev"]: 165 | val = str(lm.get_stat(f"{dtype}_val")) 166 | f.write(val+"\n") 167 | -------------------------------------------------------------------------------- /model/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 
Check that at least one argument is provided
4 | if [ "$#" -lt 1 ]; then
5 |     echo "Usage: $0 {categorical|arousal|dominance|valence|multitask|all}"
6 |     exit 1
7 | fi
8 | 
9 | # Function definitions for each task
10 | do_categorical() {
11 |     echo "Downloading categorical model"
12 |     wget https://lab-msp.com/MODELS/Odyssey_Baselines/weight_cat_ser.zip
13 |     unzip weight_cat_ser.zip
14 |     rm weight_cat_ser.zip
15 | }
16 | 
17 | do_arousal() {
18 |     echo "Downloading arousal model"
19 |     wget https://lab-msp.com/MODELS/Odyssey_Baselines/dim_aro_ser.zip
20 |     unzip dim_aro_ser.zip
21 |     rm dim_aro_ser.zip
22 | }
23 | 
24 | do_dominance() {
25 |     echo "Downloading dominance model"
26 |     wget https://lab-msp.com/MODELS/Odyssey_Baselines/dim_dom_ser.zip
27 |     unzip dim_dom_ser.zip
28 |     rm dim_dom_ser.zip
29 | }
30 | 
31 | do_valence() {
32 |     echo "Downloading valence model"
33 |     wget https://lab-msp.com/MODELS/Odyssey_Baselines/dim_val_ser.zip
34 |     unzip dim_val_ser.zip
35 |     rm dim_val_ser.zip
36 | }
37 | 
38 | do_multitask() {
39 |     echo "Downloading multitask model"
40 |     wget https://lab-msp.com/MODELS/Odyssey_Baselines/dim_ser.zip
41 |     unzip dim_ser.zip
42 |     rm dim_ser.zip
43 | }
44 | 
45 | # Main logic to process the input arguments
46 | for arg in "$@"
47 | do
48 |     case $arg in
49 |         categorical)
50 |             do_categorical
51 |             ;;
52 |         arousal)
53 |             do_arousal
54 |             ;;
55 |         dominance)
56 |             do_dominance
57 |             ;;
58 |         valence)
59 |             do_valence
60 |             ;;
61 |         multitask)
62 |             do_multitask
63 |             ;;
64 |         all)
65 |             do_categorical
66 |             do_arousal
67 |             do_dominance
68 |             do_valence
69 |             do_multitask
70 |             ;;
71 |         *)
72 |             echo "Invalid argument: $arg"
73 |             echo "Usage: $0 {categorical|arousal|dominance|valence|multitask|all}"
74 |             exit 2
75 |             ;;
76 |     esac
77 | done
78 | 
79 | exit 0
80 | 
--------------------------------------------------------------------------------
/net/__init__.py:
--------------------------------------------------------------------------------
1 | from .ser import *
2 | from .pooling import *
--------------------------------------------------------------------------------
/net/pooling.py:
--------------------------------------------------------------------------------
1 | """
2 | Common pooling methods
3 | 
4 | Authors:
5 |  * Leo 2022
6 |  * Haibin Wu 2022
7 | """
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | 
13 | __all__ = [
14 |     "MeanPooling",
15 |     "AttentiveStatisticsPooling"
16 | ]
17 | 
18 | 
19 | class Pooling(nn.Module):
20 |     def __init__(self):
21 |         super().__init__()
22 |     def compute_length_from_mask(self, mask):
23 |         """
24 |         mask: (batch_size, T)
25 |         Assuming that the sampling rate is 16kHz, the frame shift is 20ms
26 |         """
27 |         wav_lens = torch.sum(mask, dim=1) # (batch_size, )
28 |         feat_lens = torch.div(wav_lens-1, 16000*0.02, rounding_mode="floor") + 1
29 |         feat_lens = feat_lens.int().tolist()
30 |         return feat_lens
31 | 
32 |     def forward(self, x, mask):
33 |         raise NotImplementedError
34 | 
35 | class MeanPooling(Pooling):
36 |     def __init__(self):
37 |         super().__init__()
38 |     def forward(self, xs, mask):
39 |         """
40 |         xs: (batch_size, T, feat_dim)
41 |         mask: (batch_size, T)
42 | 
43 |         => output: (batch_size, feat_dim)
44 |         """
45 |         feat_lens = self.compute_length_from_mask(mask)
46 |         pooled_list = []
47 |         for x, feat_len in zip(xs, feat_lens):
48 |             pooled = torch.mean(x[:feat_len], dim=0) # (feat_dim, )
49 |             pooled_list.append(pooled)
50 |         pooled = torch.stack(pooled_list, dim=0) # (batch_size, feat_dim)
51 |         return pooled
52 | 
53 | 
54 | class 
AttentiveStatisticsPooling(Pooling): 55 | """ 56 | AttentiveStatisticsPooling 57 | Paper: Attentive Statistics Pooling for Deep Speaker Embedding 58 | Link: https://arxiv.org/pdf/1803.10963.pdf 59 | """ 60 | def __init__(self, input_size): 61 | super().__init__() 62 | self._indim = input_size 63 | self.sap_linear = nn.Linear(input_size, input_size) 64 | self.attention = nn.Parameter(torch.FloatTensor(input_size, 1)) 65 | torch.nn.init.normal_(self.attention, mean=0, std=1) 66 | 67 | def forward(self, xs, mask): 68 | """ 69 | xs: (batch_size, T, feat_dim) 70 | mask: (batch_size, T) 71 | 72 | => output: (batch_size, feat_dim*2) 73 | """ 74 | feat_lens = self.compute_length_from_mask(mask) 75 | pooled_list = [] 76 | for x, feat_len in zip(xs, feat_lens): 77 | x = x[:feat_len].unsqueeze(0) 78 | h = torch.tanh(self.sap_linear(x)) 79 | w = torch.matmul(h, self.attention).squeeze(dim=2) 80 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1) 81 | mu = torch.sum(x * w, dim=1) 82 | rh = torch.sqrt((torch.sum((x**2) * w, dim=1) - mu**2).clamp(min=1e-5)) 83 | x = torch.cat((mu, rh), 1).squeeze(0) 84 | pooled_list.append(x) 85 | return torch.stack(pooled_list) 86 | 87 | 88 | -------------------------------------------------------------------------------- /net/ser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class EmotionRegression(nn.Module): 6 | def __init__(self, *args, **kwargs): 7 | super(EmotionRegression, self).__init__() 8 | input_dim = args[0] 9 | hidden_dim = args[1] 10 | num_layers = args[2] 11 | output_dim = args[3] 12 | p = kwargs.get("dropout", 0.5) 13 | 14 | self.fc=nn.ModuleList([ 15 | nn.Sequential( 16 | nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(p) 17 | ) 18 | ]) 19 | for lidx in range(num_layers-1): 20 | self.fc.append( 21 | nn.Sequential( 22 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(p) 23 | ) 24 | ) 25 | self.out = nn.Sequential( 26 | nn.Linear(hidden_dim, output_dim) 27 | ) 28 | 29 | self.inp_drop = nn.Dropout(p) 30 | def get_repr(self, x): 31 | h = self.inp_drop(x) 32 | for lidx, fc in enumerate(self.fc): 33 | h=fc(h) 34 | return h 35 | 36 | def forward(self, x): 37 | h=self.get_repr(x) 38 | result = self.out(h) 39 | return result -------------------------------------------------------------------------------- /process_labels_for_categorical.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | # Use this code to create a .csv file with the necessary format needed for 4 | # categorical emotion recognition model 5 | 6 | # Load Original label_consensus.csv file provided with dataset 7 | df = pd.read_csv('path/to/original/label_consensus.csv') 8 | 9 | # Define the emotions 10 | emotions = ["Angry", "Sad", "Happy", "Surprise", "Fear", "Disgust", "Contempt", "Neutral"] 11 | emotion_codes = ["A", "S", "H", "U", "F", "D", "C", "N"] 12 | 13 | # Create a dictionary for one-hot encoding 14 | one_hot_dict = {e: [1.0 if e == ec else 0.0 for ec in emotion_codes] for e in emotion_codes} 15 | 16 | # Filter out rows with undefined EmoClass 17 | df = df[df['EmoClass'].isin(emotion_codes)] 18 | 19 | # Apply one-hot encoding 20 | for i, e in enumerate(emotion_codes): 21 | df[emotions[i]] = df['EmoClass'].apply(lambda x: one_hot_dict[x][i]) 22 | 23 | # Select relevant columns for the new CSV 24 | df_final = df[['FileName', 
*emotions, 'Split_Set']] 25 | 26 | # Save the processed data to a new CSV file 27 | df_final.to_csv('processed_labels.csv', index=False) 28 | 29 | print("Processing complete. New file saved as 'processed_labels.csv'") 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audioread==3.0.0 2 | brotlipy==0.7.0 3 | decorator==5.1.1 4 | filelock==3.12.4 5 | fsspec==2023.9.1 6 | huggingface-hub==0.17.2 7 | joblib==1.3.2 8 | lazy_loader==0.3 9 | librosa==0.10.1 10 | llvmlite==0.40.1 11 | msgpack==1.0.5 12 | numba==0.57.1 13 | numpy==1.24.4 14 | packaging==23.1 15 | pandas==2.1.0 16 | Pillow==9.4.0 17 | platformdirs==3.10.0 18 | pooch==1.7.0 19 | python-dateutil==2.8.2 20 | pytz==2023.3.post1 21 | PyYAML==6.0.1 22 | regex==2023.8.8 23 | safetensors==0.3.3 24 | scikit-learn==1.3.0 25 | scipy==1.11.2 26 | six==1.16.0 27 | soundfile==0.12.1 28 | soxr==0.3.6 29 | threadpoolctl==3.2.0 30 | tokenizers==0.13.3 31 | torch==1.13.1 32 | torchaudio==0.13.1 33 | torchvision==0.14.1 34 | tqdm==4.66.1 35 | transformers==4.33.2 36 | tzdata==2023.3 37 | -------------------------------------------------------------------------------- /run_arousal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ssl_type=wavlm-large 3 | 4 | # Train 5 | pool_type=AttentiveStatisticsPooling 6 | for seed in 7; do 7 | python train_ft_aro_dim_ser.py \ 8 | --seed=${seed} \ 9 | --ssl_type=${ssl_type} \ 10 | --batch_size=32 \ 11 | --accumulation_steps=4 \ 12 | --lr=1e-5 \ 13 | --epochs=20 \ 14 | --pooling_type=${pool_type} \ 15 | --model_path=model/dim_aro_ser/wavLM_adamW/${seed} || exit 0; 16 | 17 | python eval_aro_dim_ser.py \ 18 | --ssl_type=${ssl_type} \ 19 | --pooling_type=${pool_type} \ 20 | --model_path=model/dim_aro_ser/wavLM_adamW/${seed} \ 21 | --store_path=result/dim_aro_ser/wavLM_adamW/${seed}.txt || exit 0; 22 | 23 | done 24 | -------------------------------------------------------------------------------- /run_cat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ssl_type=wavlm-large 3 | 4 | # Train 5 | pool_type=AttentiveStatisticsPooling 6 | for seed in 7; do 7 | python train_ft_cat_ser_weighted.py \ 8 | --seed=${seed} \ 9 | --ssl_type=${ssl_type} \ 10 | --batch_size=32 \ 11 | --accumulation_steps=4 \ 12 | --lr=1e-5 \ 13 | --epochs=20 \ 14 | --pooling_type=${pool_type} \ 15 | --model_path=model/weight_cat_ser/w2v_adamW/${seed} || exit 0; 16 | 17 | python eval_cat_ser_weighted.py \ 18 | --ssl_type=${ssl_type} \ 19 | --pooling_type=${pool_type} \ 20 | --model_path=model/weight_cat_ser/w2v_adamW/${seed} \ 21 | --store_path=result/weight_cat_ser/w2v_adamW/${seed}.txt || exit 0; 22 | 23 | done 24 | -------------------------------------------------------------------------------- /run_dim.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ssl_type=wavlm-large 3 | 4 | # Train 5 | pool_type=AttentiveStatisticsPooling 6 | for seed in 7; do 7 | python train_ft_dim_ser.py \ 8 | --seed=${seed} \ 9 | --ssl_type=${ssl_type} \ 10 | --batch_size=32 \ 11 | --accumulation_steps=4 \ 12 | --lr=1e-5 \ 13 | --epochs=20 \ 14 | --pooling_type=${pool_type} \ 15 | --model_path=model/dim_ser/wavLM_adamW/${seed} || exit 0; 16 | 17 | # Evaluation on Test3 and save results using format required by challenge 18 | python eval_dim_ser_test3.py \ 19 | 
--ssl_type=${ssl_type} \ 20 | --pooling_type=${pool_type} \ 21 | --model_path=model/dim_ser/wavLM_adamW/${seed} \ 22 | --store_path=result/dim_ser/wavLM_adamW/${seed}.txt || exit 0; 23 | 24 | # General evaluation code for sets with labels 25 | python eval_dim_ser.py \ 26 | --ssl_type=${ssl_type} \ 27 | --pooling_type=${pool_type} \ 28 | --model_path=model/dim_ser/wavLM_adamW/${seed} \ 29 | --store_path=result/dim_ser/wavLM_adamW/${seed}.txt || exit 0; 30 | 31 | done 32 | -------------------------------------------------------------------------------- /run_dominance.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ssl_type=wavlm-large 3 | 4 | # Train 5 | pool_type=AttentiveStatisticsPooling 6 | for seed in 7; do 7 | python train_ft_dom_dim_ser.py \ 8 | --seed=${seed} \ 9 | --ssl_type=${ssl_type} \ 10 | --batch_size=32 \ 11 | --accumulation_steps=4 \ 12 | --lr=1e-5 \ 13 | --epochs=20 \ 14 | --pooling_type=${pool_type} \ 15 | --model_path=model/dim_dom_ser/wavLM_adamW/${seed} || exit 0; 16 | 17 | python eval_dom_dim_ser.py \ 18 | --ssl_type=${ssl_type} \ 19 | --pooling_type=${pool_type} \ 20 | --model_path=model/dim_dom_ser/wavLM_adamW/${seed} \ 21 | --store_path=result/dim_dom_ser/wavLM_adamW/${seed}.txt || exit 0; 22 | 23 | done 24 | -------------------------------------------------------------------------------- /run_valence.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ssl_type=wavlm-large 3 | 4 | # Train 5 | pool_type=AttentiveStatisticsPooling 6 | for seed in 7; do 7 | python train_ft_val_dim_ser.py \ 8 | --seed=${seed} \ 9 | --ssl_type=${ssl_type} \ 10 | --batch_size=32 \ 11 | --accumulation_steps=4 \ 12 | --lr=1e-5 \ 13 | --epochs=20 \ 14 | --pooling_type=${pool_type} \ 15 | --model_path=model/dim_val_ser/wavLM_adamW/${seed} || exit 0; 16 | 17 | python eval_val_dim_ser.py \ 18 | --ssl_type=${ssl_type} \ 19 | --pooling_type=${pool_type} \ 20 | --model_path=model/dim_val_ser/wavLM_adamW/${seed} \ 21 | --store_path=result/dim_val_ser/wavLM_adamW/${seed}.txt || exit 0; 22 | 23 | done 24 | -------------------------------------------------------------------------------- /spec-file.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda 6 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.conda 7 | https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2022.12.7-ha878542_0.conda 8 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2021.4.0-h06a4308_3561.conda 9 | https://repo.anaconda.com/pkgs/main/linux-64/ld_impl_linux-64-2.35.1-h7274673_9.conda 10 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-9.3.0-hd4cf53a_17.conda 11 | https://conda.anaconda.org/pytorch/noarch/pytorch-mutex-1.0-cuda.tar.bz2 12 | https://repo.anaconda.com/pkgs/main/noarch/tzdata-2022a-hda174b7_0.conda 13 | https://repo.anaconda.com/pkgs/main/linux-64/libgomp-9.3.0-h5101ec6_17.conda 14 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2021.4.0-h06a4308_640.conda 15 | https://repo.anaconda.com/pkgs/main/linux-64/_openmp_mutex-4.5-1_gnu.tar.bz2 16 | https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-12_linux64_mkl.tar.bz2 17 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-9.3.0-h5101ec6_17.conda 18 | 
https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-12_linux64_mkl.tar.bz2 19 | https://repo.anaconda.com/pkgs/main/linux-64/bzip2-1.0.8-h7b6447c_0.conda 20 | https://repo.anaconda.com/pkgs/main/linux-64/cudatoolkit-11.3.1-h2bc3f7f_2.conda 21 | https://repo.anaconda.com/pkgs/main/linux-64/giflib-5.2.1-h7b6447c_0.conda 22 | https://repo.anaconda.com/pkgs/main/linux-64/gmp-6.2.1-h295c915_3.conda 23 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9e-h7f8727e_0.conda 24 | https://repo.anaconda.com/pkgs/main/linux-64/lame-3.100-h7b6447c_0.conda 25 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.3-he6710b0_2.conda 26 | https://repo.anaconda.com/pkgs/main/linux-64/libiconv-1.16-h7f8727e_2.conda 27 | https://repo.anaconda.com/pkgs/main/linux-64/libtasn1-4.16.0-h27cfd23_0.conda 28 | https://repo.anaconda.com/pkgs/main/linux-64/libunistring-0.9.10-h27cfd23_0.conda 29 | https://repo.anaconda.com/pkgs/main/linux-64/libwebp-base-1.2.2-h7f8727e_0.conda 30 | https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.9.3-h295c915_1.conda 31 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.3-h7f8727e_2.conda 32 | https://repo.anaconda.com/pkgs/main/linux-64/openh264-2.1.1-h4ff587b_0.conda 33 | https://repo.anaconda.com/pkgs/main/linux-64/openssl-1.1.1t-h7f8727e_0.conda 34 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.5-h7b6447c_0.conda 35 | https://repo.anaconda.com/pkgs/main/linux-64/yaml-0.2.5-h7b6447c_0.conda 36 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7f8727e_4.conda 37 | https://conda.anaconda.org/conda-forge/linux-64/libfaiss-1.7.1-cuda112h5bea7ad_1_cuda.tar.bz2 38 | https://conda.anaconda.org/conda-forge/linux-64/libfaiss-avx2-1.7.1-cuda112h1234567_1_cuda.tar.bz2 39 | https://repo.anaconda.com/pkgs/main/linux-64/libidn2-2.3.2-h7f8727e_0.conda 40 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.37-hbc83047_0.conda 41 | https://repo.anaconda.com/pkgs/main/linux-64/nettle-3.7.3-hbbd107a_1.conda 42 | https://repo.anaconda.com/pkgs/main/linux-64/readline-8.1.2-h7f8727e_1.conda 43 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.11-h1ccaba5_0.conda 44 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.4.9-haebb681_0.conda 45 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.11.0-h70c0345_0.conda 46 | https://repo.anaconda.com/pkgs/main/linux-64/gnutls-3.6.15-he1e5248_0.conda 47 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.2.0-h85742a9_0.conda 48 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.38.2-hc218d9a_0.conda 49 | https://conda.anaconda.org/pytorch/linux-64/ffmpeg-4.3-hf484d3e_0.tar.bz2 50 | https://repo.anaconda.com/pkgs/main/linux-64/lcms2-2.12-h3be6417_0.conda 51 | https://repo.anaconda.com/pkgs/main/linux-64/libwebp-1.2.2-h55f646e_0.conda 52 | https://repo.anaconda.com/pkgs/main/linux-64/python-3.9.7-h12debd9_1.conda 53 | https://conda.anaconda.org/conda-forge/noarch/certifi-2022.12.7-pyhd8ed1ab_0.conda 54 | https://repo.anaconda.com/pkgs/main/noarch/charset-normalizer-2.0.4-pyhd3eb1b0_0.conda 55 | https://repo.anaconda.com/pkgs/main/noarch/colorama-0.4.4-pyhd3eb1b0_0.conda 56 | https://repo.anaconda.com/pkgs/main/noarch/idna-3.3-pyhd3eb1b0_0.conda 57 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-9.0.1-py39h22f2fdc_0.conda 58 | https://repo.anaconda.com/pkgs/main/linux-64/pycosat-0.6.3-py39h27cfd23_0.conda 59 | https://repo.anaconda.com/pkgs/main/noarch/pycparser-2.21-pyhd3eb1b0_0.conda 60 | https://repo.anaconda.com/pkgs/main/linux-64/pysocks-1.7.1-py39h06a4308_0.conda 61 | 
https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.9-2_cp39.tar.bz2 62 | https://repo.anaconda.com/pkgs/main/linux-64/ruamel_yaml-0.15.100-py39h27cfd23_0.conda 63 | https://repo.anaconda.com/pkgs/main/noarch/six-1.16.0-pyhd3eb1b0_1.conda 64 | https://repo.anaconda.com/pkgs/main/noarch/toolz-0.11.2-pyhd3eb1b0_0.conda 65 | https://repo.anaconda.com/pkgs/main/linux-64/typing_extensions-4.3.0-py39h06a4308_0.conda 66 | https://repo.anaconda.com/pkgs/main/noarch/wheel-0.37.1-pyhd3eb1b0_0.conda 67 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.15.0-py39hd667e15_1.conda 68 | https://repo.anaconda.com/pkgs/main/linux-64/cytoolz-0.11.0-py39h27cfd23_0.conda 69 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.4.0-py39h7f8727e_0.conda 70 | https://conda.anaconda.org/pytorch/linux-64/pytorch-1.12.1-py3.9_cuda11.3_cudnn8.3.2_0.tar.bz2 71 | https://repo.anaconda.com/pkgs/main/linux-64/setuptools-58.0.4-py39h06a4308_0.conda 72 | https://repo.anaconda.com/pkgs/main/noarch/tqdm-4.63.0-pyhd3eb1b0_0.conda 73 | https://repo.anaconda.com/pkgs/main/linux-64/brotlipy-0.7.0-py39h27cfd23_1003.conda 74 | https://repo.anaconda.com/pkgs/main/linux-64/conda-package-handling-1.8.1-py39h7f8727e_0.conda 75 | https://repo.anaconda.com/pkgs/main/linux-64/cryptography-36.0.0-py39h9ce1e76_0.conda 76 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.21.5-py39hf524024_2.conda 77 | https://repo.anaconda.com/pkgs/main/linux-64/pip-21.2.4-py39h06a4308_0.conda 78 | https://repo.anaconda.com/pkgs/main/noarch/pyopenssl-22.0.0-pyhd3eb1b0_0.conda 79 | https://repo.anaconda.com/pkgs/main/noarch/urllib3-1.26.8-pyhd3eb1b0_0.conda 80 | https://repo.anaconda.com/pkgs/main/noarch/requests-2.27.1-pyhd3eb1b0_0.conda 81 | https://repo.anaconda.com/pkgs/main/linux-64/conda-4.14.0-py39h06a4308_0.conda 82 | https://conda.anaconda.org/conda-forge/linux-64/faiss-1.7.1-py39cuda112h5ca99f2_1_cuda.tar.bz2 83 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.3.1-py39hd3c417c_0.conda 84 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.2.2-py39h51133e4_0.conda 85 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.21.5-py39he7a7128_2.conda 86 | https://conda.anaconda.org/pytorch/linux-64/torchaudio-0.12.1-py39_cu113.tar.bz2 87 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.13.1-py39_cu113.tar.bz2 88 | -------------------------------------------------------------------------------- /train_ft_aro_dim_ser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | import numpy as np 8 | import pickle as pk 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import glob 12 | import librosa 13 | import copy 14 | 15 | # PyTorch Modules 16 | import torch 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | import torch.nn.functional as F 20 | from torch.utils.data import ConcatDataset, DataLoader 21 | 22 | import torch.optim as optim 23 | from transformers import AutoModel 24 | import importlib 25 | # Self-Written Modules 26 | sys.path.append(os.getcwd()) 27 | import net 28 | import utils 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--seed", type=int, default=100) 33 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 34 | parser.add_argument("--batch_size", type=int, default=32) 35 | parser.add_argument("--accumulation_steps", type=int, default=1) 36 | 
parser.add_argument("--epochs", type=int, default=10) 37 | parser.add_argument("--lr", type=float, default=0.001) 38 | parser.add_argument("--model_path", type=str, default="./temp") 39 | parser.add_argument("--head_dim", type=int, default=1024) 40 | 41 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 42 | args = parser.parse_args() 43 | 44 | utils.set_deterministic(args.seed) 45 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 46 | assert SSL_TYPE != None, print("Invalid SSL type!") 47 | BATCH_SIZE = args.batch_size 48 | ACCUMULATION_STEP = args.accumulation_steps 49 | assert (ACCUMULATION_STEP > 0) and (BATCH_SIZE % ACCUMULATION_STEP == 0) 50 | EPOCHS=args.epochs 51 | LR=args.lr 52 | MODEL_PATH = args.model_path 53 | 54 | 55 | import json 56 | from collections import defaultdict 57 | config_path = "config.json" 58 | with open(config_path, "r") as f: 59 | config = json.load(f) 60 | audio_path = config["wav_dir"] 61 | label_path = config["label_path"] 62 | 63 | total_dataset=dict() 64 | total_dataloader=dict() 65 | for dtype in ["train", "dev"]: 66 | cur_utts, cur_labs = utils.load_adv_arousal(label_path, dtype) 67 | cur_wavs = utils.load_audio(audio_path, cur_utts) 68 | if dtype == "train": 69 | cur_wav_set = utils.WavSet(cur_wavs) 70 | cur_wav_set.save_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 71 | else: 72 | if dtype == "dev": 73 | wav_mean = total_dataset["train"].datasets[0].wav_mean 74 | wav_std = total_dataset["train"].datasets[0].wav_std 75 | elif dtype == "test": 76 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 77 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 78 | ######################################################## 79 | cur_bs = BATCH_SIZE // ACCUMULATION_STEP if dtype == "train" else 1 80 | is_shuffle=True if dtype == "train" else False 81 | ######################################################## 82 | cur_emo_set = utils.ADV_EmoSet(cur_labs) 83 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 84 | total_dataloader[dtype] = DataLoader( 85 | total_dataset[dtype], batch_size=cur_bs, shuffle=is_shuffle, 86 | pin_memory=True, num_workers=4, 87 | collate_fn=utils.collate_fn_wav_lab_mask 88 | ) 89 | 90 | print("Loading pre-trained ", SSL_TYPE, " model...") 91 | 92 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 93 | ssl_model.freeze_feature_encoder() 94 | ssl_model.eval(); ssl_model.cuda() 95 | 96 | ########## Implement pooling method ########## 97 | feat_dim = ssl_model.config.hidden_size 98 | 99 | pool_net = getattr(net, args.pooling_type) 100 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 101 | if args.pooling_type in attention_pool_type_list: 102 | is_attentive_pooling = True 103 | pool_model = pool_net(feat_dim) 104 | else: 105 | is_attentive_pooling = False 106 | pool_model = pool_net() 107 | print(pool_model) 108 | pool_model.cuda() 109 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 110 | dh_input_dim = feat_dim * 2 \ 111 | if args.pooling_type in concat_pool_type_list \ 112 | else feat_dim 113 | 114 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 1, dropout=0.5) 115 | ############################################## 116 | ser_model.eval(); ser_model.cuda() 117 | 118 | ssl_opt = torch.optim.AdamW(ssl_model.parameters(), LR) 119 | ser_opt = torch.optim.AdamW(ser_model.parameters(), LR) 120 | 121 | # scaler = GradScaler() 122 | ssl_opt.zero_grad(set_to_none=True) 123 | ser_opt.zero_grad(set_to_none=True) 124 | 125 | 
if is_attentive_pooling: 126 | pool_opt = torch.optim.AdamW(pool_model.parameters(), LR) 127 | pool_opt.zero_grad(set_to_none=True) 128 | 129 | lm = utils.LogManager() 130 | lm.alloc_stat_type_list(["train_aro"]) 131 | lm.alloc_stat_type_list(["dev_aro"]) 132 | 133 | min_epoch=0 134 | min_loss=1e10 135 | 136 | for epoch in range(EPOCHS): 137 | print("Epoch: ", epoch) 138 | lm.init_stat() 139 | ssl_model.train() 140 | pool_model.train() 141 | ser_model.train() 142 | batch_cnt = 0 143 | 144 | for xy_pair in tqdm(total_dataloader["train"]): 145 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 146 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 147 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 148 | 149 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state # (B, T, 1024) 150 | ssl = pool_model(ssl, mask) 151 | 152 | emo_pred = ser_model(ssl) 153 | 154 | ccc = utils.CCC_loss(emo_pred, y) 155 | loss = 1.0 - ccc 156 | total_loss = torch.sum(loss) / ACCUMULATION_STEP 157 | total_loss.backward() 158 | if (batch_cnt+1) % ACCUMULATION_STEP == 0 or (batch_cnt+1) == len(total_dataloader["train"]): 159 | ssl_opt.step() 160 | ser_opt.step() 161 | if is_attentive_pooling: 162 | pool_opt.step() 163 | ssl_opt.zero_grad(set_to_none=True) 164 | ser_opt.zero_grad(set_to_none=True) 165 | if is_attentive_pooling: 166 | pool_opt.zero_grad(set_to_none=True) 167 | batch_cnt += 1 168 | 169 | # Logging 170 | lm.add_torch_stat("train_aro", ccc[0]) 171 | 172 | 173 | ssl_model.eval() 174 | pool_model.eval() 175 | ser_model.eval() 176 | total_pred = [] 177 | total_y = [] 178 | for xy_pair in tqdm(total_dataloader["dev"]): 179 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 180 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 181 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 182 | 183 | with torch.no_grad(): 184 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 185 | ssl = pool_model(ssl, mask) 186 | emo_pred = ser_model(ssl) 187 | 188 | total_pred.append(emo_pred) 189 | total_y.append(y) 190 | 191 | # CCC calculation 192 | total_pred = torch.cat(total_pred, 0) 193 | total_y = torch.cat(total_y, 0) 194 | ccc = utils.CCC_loss(total_pred, total_y) 195 | # Logging 196 | lm.add_torch_stat("dev_aro", ccc[0]) 197 | 198 | 199 | # Save model 200 | lm.print_stat() 201 | 202 | 203 | 204 | dev_loss = 1.0 - lm.get_stat("dev_aro") 205 | if min_loss > dev_loss: 206 | min_epoch = epoch 207 | min_loss = dev_loss 208 | 209 | print("Save",min_epoch) 210 | print("Loss",3.0-min_loss) 211 | save_model_list = ["ser", "ssl"] 212 | if is_attentive_pooling: 213 | save_model_list.append("pool") 214 | 215 | torch.save(ser_model.state_dict(), \ 216 | os.path.join(MODEL_PATH, "final_ser.pt")) 217 | torch.save(ssl_model.state_dict(), \ 218 | os.path.join(MODEL_PATH, "final_ssl.pt")) 219 | if is_attentive_pooling: 220 | torch.save(pool_model.state_dict(), \ 221 | os.path.join(MODEL_PATH, "final_pool.pt")) -------------------------------------------------------------------------------- /train_ft_cat_ser_weighted.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | import numpy as np 8 | import pickle as pk 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import glob 12 | import librosa 13 | import copy 14 | 15 | # PyTorch Modules 16 | import torch 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | import 
torch.nn.functional as F 20 | from torch.utils.data import ConcatDataset, DataLoader 21 | import torch.optim as optim 22 | from transformers import AutoModel 23 | import importlib 24 | # Self-Written Modules 25 | sys.path.append(os.getcwd()) 26 | import net 27 | import utils 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--seed", type=int, default=100) 32 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 33 | parser.add_argument("--batch_size", type=int, default=32) 34 | parser.add_argument("--accumulation_steps", type=int, default=1) 35 | parser.add_argument("--epochs", type=int, default=10) 36 | parser.add_argument("--lr", type=float, default=0.001) 37 | parser.add_argument("--model_path", type=str, default="./temp") 38 | parser.add_argument("--head_dim", type=int, default=1024) 39 | 40 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 41 | args = parser.parse_args() 42 | 43 | utils.set_deterministic(args.seed) 44 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 45 | assert SSL_TYPE != None, print("Invalid SSL type!") 46 | BATCH_SIZE = args.batch_size 47 | ACCUMULATION_STEP = args.accumulation_steps 48 | assert (ACCUMULATION_STEP > 0) and (BATCH_SIZE % ACCUMULATION_STEP == 0) 49 | EPOCHS=args.epochs 50 | LR=args.lr 51 | MODEL_PATH = args.model_path 52 | 53 | 54 | import json 55 | from collections import defaultdict 56 | config_path = "config_cat.json" 57 | with open(config_path, "r") as f: 58 | config = json.load(f) 59 | audio_path = config["wav_dir"] 60 | label_path = config["label_path"] 61 | 62 | import pandas as pd 63 | import numpy as np 64 | 65 | # Load the CSV file 66 | df = pd.read_csv(label_path) 67 | 68 | # Filter out only 'Train' samples 69 | train_df = df[df['Split_Set'] == 'Train'] 70 | 71 | # Classes (emotions) 72 | classes = ['Angry', 'Sad', 'Happy', 'Surprise', 'Fear', 'Disgust', 'Contempt', 'Neutral'] 73 | 74 | # Calculate class frequencies 75 | class_frequencies = train_df[classes].sum().to_dict() 76 | 77 | # Total number of samples 78 | total_samples = len(train_df) 79 | 80 | # Calculate class weights 81 | class_weights = {cls: total_samples / (len(classes) * freq) if freq != 0 else 0 for cls, freq in class_frequencies.items()} 82 | 83 | print(class_weights) 84 | 85 | # Convert to list in the order of classes 86 | weights_list = [class_weights[cls] for cls in classes] 87 | 88 | # Convert to PyTorch tensor 89 | class_weights_tensor = torch.tensor(weights_list, device='cuda', dtype=torch.float) 90 | 91 | 92 | # Print or return the tensor 93 | print(class_weights_tensor) 94 | 95 | 96 | total_dataset=dict() 97 | total_dataloader=dict() 98 | for dtype in ["train", "dev"]: 99 | cur_utts, cur_labs = utils.load_cat_emo_label(label_path, dtype) 100 | cur_wavs = utils.load_audio(audio_path, cur_utts) 101 | if dtype == "train": 102 | cur_wav_set = utils.WavSet(cur_wavs) 103 | cur_wav_set.save_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 104 | else: 105 | if dtype == "dev": 106 | wav_mean = total_dataset["train"].datasets[0].wav_mean 107 | wav_std = total_dataset["train"].datasets[0].wav_std 108 | elif dtype == "test": 109 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 110 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 111 | ######################################################## 112 | cur_bs = BATCH_SIZE // ACCUMULATION_STEP if dtype == "train" else 1 113 | is_shuffle=True if dtype == "train" else False 114 | 
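# batch_size=1 for the dev loader: utterances vary in length, so dev samples
# are scored one at a time, while the train batch is divided by
# ACCUMULATION_STEP so the effective batch after accumulation stays BATCH_SIZE.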
######################################################## 115 | cur_emo_set = utils.CAT_EmoSet(cur_labs) 116 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 117 | total_dataloader[dtype] = DataLoader( 118 | total_dataset[dtype], batch_size=cur_bs, shuffle=is_shuffle, 119 | pin_memory=True, num_workers=4, 120 | collate_fn=utils.collate_fn_wav_lab_mask 121 | ) 122 | 123 | print("Loading pre-trained ", SSL_TYPE, " model...") 124 | 125 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 126 | ssl_model.freeze_feature_encoder() 127 | ssl_model.eval(); ssl_model.cuda() 128 | 129 | ########## Implement pooling method ########## 130 | feat_dim = ssl_model.config.hidden_size 131 | 132 | pool_net = getattr(net, args.pooling_type) 133 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 134 | if args.pooling_type in attention_pool_type_list: 135 | is_attentive_pooling = True 136 | pool_model = pool_net(feat_dim) 137 | else: 138 | is_attentive_pooling = False 139 | pool_model = pool_net() 140 | print(pool_model) 141 | pool_model.cuda() 142 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 143 | dh_input_dim = feat_dim * 2 \ 144 | if args.pooling_type in concat_pool_type_list \ 145 | else feat_dim 146 | 147 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 8, dropout=0.5) 148 | ############################################## 149 | ser_model.eval(); ser_model.cuda() 150 | 151 | ssl_opt = torch.optim.AdamW(ssl_model.parameters(), LR) 152 | ser_opt = torch.optim.AdamW(ser_model.parameters(), LR) 153 | 154 | # scaler = GradScaler() 155 | ssl_opt.zero_grad(set_to_none=True) 156 | ser_opt.zero_grad(set_to_none=True) 157 | 158 | if is_attentive_pooling: 159 | pool_opt = torch.optim.AdamW(pool_model.parameters(), LR) 160 | pool_opt.zero_grad(set_to_none=True) 161 | 162 | lm = utils.LogManager() 163 | lm.alloc_stat_type_list(["train_loss"]) 164 | lm.alloc_stat_type_list(["dev_loss"]) 165 | 166 | min_epoch=0 167 | min_loss=1e10 168 | 169 | for epoch in range(EPOCHS): 170 | print("Epoch: ", epoch) 171 | lm.init_stat() 172 | ssl_model.train() 173 | pool_model.train() 174 | ser_model.train() 175 | batch_cnt = 0 176 | 177 | for xy_pair in tqdm(total_dataloader["train"]): 178 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 179 | y = xy_pair[1]; y=y.max(dim=1)[1]; y=y.cuda(non_blocking=True).long() 180 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 181 | 182 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state # (B, T, 1024) 183 | ssl = pool_model(ssl, mask) 184 | 185 | emo_pred = ser_model(ssl) 186 | 187 | loss = utils.CE_weight_category(emo_pred, y, class_weights_tensor) 188 | 189 | total_loss = loss / ACCUMULATION_STEP 190 | total_loss.backward() 191 | if (batch_cnt+1) % ACCUMULATION_STEP == 0 or (batch_cnt+1) == len(total_dataloader["train"]): 192 | 193 | ssl_opt.step() 194 | 195 | ser_opt.step() 196 | 197 | if is_attentive_pooling: 198 | 199 | pool_opt.step() 200 | 201 | ssl_opt.zero_grad(set_to_none=True) 202 | ser_opt.zero_grad(set_to_none=True) 203 | if is_attentive_pooling: 204 | pool_opt.zero_grad(set_to_none=True) 205 | batch_cnt += 1 206 | 207 | # Logging 208 | lm.add_torch_stat("train_loss", loss) 209 | 210 | 211 | ssl_model.eval() 212 | pool_model.eval() 213 | ser_model.eval() 214 | total_pred = [] 215 | total_y = [] 216 | for xy_pair in tqdm(total_dataloader["dev"]): 217 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 218 | y = xy_pair[1]; y=y.max(dim=1)[1]; y=y.cuda(non_blocking=True).long() 219 | mask = 
xy_pair[2]; mask=mask.cuda(non_blocking=True).float()
220 | 
221 | with torch.no_grad():
222 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state
223 | ssl = pool_model(ssl, mask)
224 | emo_pred = ser_model(ssl)
225 | 
226 | total_pred.append(emo_pred)
227 | total_y.append(y)
228 | 
229 | # CE loss over the full dev set
230 | total_pred = torch.cat(total_pred, 0)
231 | total_y = torch.cat(total_y, 0)
232 | loss = utils.CE_weight_category(total_pred, total_y, class_weights_tensor)
233 | # Logging
234 | lm.add_torch_stat("dev_loss", loss)
235 | 
236 | 
237 | # Save model
238 | lm.print_stat()
239 | 
240 | 
241 | dev_loss = lm.get_stat("dev_loss")
242 | if min_loss > dev_loss:
243 | min_epoch = epoch
244 | min_loss = dev_loss
245 | 
246 | print("Save",min_epoch)
247 | print("Loss",min_loss)
248 | save_model_list = ["ser", "ssl"]
249 | if is_attentive_pooling:
250 | save_model_list.append("pool")
251 | 
252 | 
253 | torch.save(ser_model.state_dict(), \
254 | os.path.join(MODEL_PATH, "final_ser.pt"))
255 | torch.save(ssl_model.state_dict(), \
256 | os.path.join(MODEL_PATH, "final_ssl.pt"))
257 | if is_attentive_pooling:
258 | torch.save(pool_model.state_dict(), \
259 | os.path.join(MODEL_PATH, "final_pool.pt"))
--------------------------------------------------------------------------------
/train_ft_dim_ser.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | # Local modules
3 | import os
4 | import sys
5 | import argparse
6 | # 3rd-Party Modules
7 | import numpy as np
8 | import pickle as pk
9 | import pandas as pd
10 | from tqdm import tqdm
11 | import glob
12 | import librosa
13 | import copy
14 | 
15 | # PyTorch Modules
16 | import torch
17 | import torch.nn as nn
18 | import torch.optim as optim
19 | import torch.nn.functional as F
20 | from torch.utils.data import ConcatDataset, DataLoader
21 | from torch.cuda.amp import GradScaler, autocast
22 | import torch.optim as optim
23 | from transformers import AutoModel
24 | import importlib
25 | # Self-Written Modules
26 | sys.path.append(os.getcwd())
27 | import net
28 | import utils
29 | 
30 | 
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("--seed", type=int, default=100)
33 | parser.add_argument("--ssl_type", type=str, default="wavlm-large")
34 | parser.add_argument("--batch_size", type=int, default=32)
35 | parser.add_argument("--accumulation_steps", type=int, default=1)
36 | parser.add_argument("--epochs", type=int, default=10)
37 | parser.add_argument("--lr", type=float, default=0.001)
38 | parser.add_argument("--model_path", type=str, default="./temp")
39 | parser.add_argument("--head_dim", type=int, default=1024)
40 | 
41 | parser.add_argument("--pooling_type", type=str, default="MeanPooling")
42 | args = parser.parse_args()
43 | 
44 | utils.set_deterministic(args.seed)
45 | SSL_TYPE = utils.get_ssl_type(args.ssl_type)
46 | assert SSL_TYPE is not None, "Invalid SSL type!"
47 | BATCH_SIZE = args.batch_size
48 | ACCUMULATION_STEP = args.accumulation_steps
49 | assert (ACCUMULATION_STEP > 0) and (BATCH_SIZE % ACCUMULATION_STEP == 0)
50 | EPOCHS=args.epochs
51 | LR=args.lr
52 | MODEL_PATH = args.model_path
53 | 
54 | 
55 | import json
56 | from collections import defaultdict
57 | config_path = "config.json"
58 | with open(config_path, "r") as f:
59 | config = json.load(f)
60 | audio_path = config["wav_dir"]
61 | label_path = config["label_path"]
62 | 
63 | total_dataset=dict()
64 | total_dataloader=dict()
65 | for dtype in ["train", "dev"]:
66 | cur_utts, cur_labs = 
utils.load_adv_emo_label(label_path, dtype) 67 | cur_wavs = utils.load_audio(audio_path, cur_utts) 68 | if dtype == "train": 69 | cur_wav_set = utils.WavSet(cur_wavs) 70 | cur_wav_set.save_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 71 | else: 72 | if dtype == "dev": 73 | wav_mean = total_dataset["train"].datasets[0].wav_mean 74 | wav_std = total_dataset["train"].datasets[0].wav_std 75 | elif dtype == "test": 76 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 77 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 78 | ######################################################## 79 | cur_bs = BATCH_SIZE // ACCUMULATION_STEP if dtype == "train" else 1 80 | is_shuffle=True if dtype == "train" else False 81 | ######################################################## 82 | cur_emo_set = utils.ADV_EmoSet(cur_labs) 83 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 84 | total_dataloader[dtype] = DataLoader( 85 | total_dataset[dtype], batch_size=cur_bs, shuffle=is_shuffle, 86 | pin_memory=True, num_workers=4, 87 | collate_fn=utils.collate_fn_wav_lab_mask 88 | ) 89 | 90 | print("Loading pre-trained ", SSL_TYPE, " model...") 91 | 92 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 93 | ssl_model.freeze_feature_encoder() 94 | ssl_model.eval(); ssl_model.cuda() 95 | 96 | ########## Implement pooling method ########## 97 | feat_dim = ssl_model.config.hidden_size 98 | 99 | pool_net = getattr(net, args.pooling_type) 100 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 101 | if args.pooling_type in attention_pool_type_list: 102 | is_attentive_pooling = True 103 | pool_model = pool_net(feat_dim) 104 | else: 105 | is_attentive_pooling = False 106 | pool_model = pool_net() 107 | print(pool_model) 108 | pool_model.cuda() 109 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 110 | dh_input_dim = feat_dim * 2 \ 111 | if args.pooling_type in concat_pool_type_list \ 112 | else feat_dim 113 | 114 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 3, dropout=0.5) 115 | ############################################## 116 | ser_model.eval(); ser_model.cuda() 117 | 118 | ssl_opt = torch.optim.AdamW(ssl_model.parameters(), LR) 119 | ser_opt = torch.optim.AdamW(ser_model.parameters(), LR) 120 | 121 | scaler = GradScaler() 122 | ssl_opt.zero_grad(set_to_none=True) 123 | ser_opt.zero_grad(set_to_none=True) 124 | 125 | if is_attentive_pooling: 126 | pool_opt = torch.optim.AdamW(pool_model.parameters(), LR) 127 | pool_opt.zero_grad(set_to_none=True) 128 | 129 | lm = utils.LogManager() 130 | lm.alloc_stat_type_list(["train_aro", "train_dom", "train_val"]) 131 | lm.alloc_stat_type_list(["dev_aro", "dev_dom", "dev_val"]) 132 | 133 | min_epoch=0 134 | min_loss=1e10 135 | 136 | for epoch in range(EPOCHS): 137 | print("Epoch: ", epoch) 138 | lm.init_stat() 139 | ssl_model.train() 140 | pool_model.train() 141 | ser_model.train() 142 | batch_cnt = 0 143 | 144 | for xy_pair in tqdm(total_dataloader["train"]): 145 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 146 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 147 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 148 | 149 | with autocast(enabled=True): 150 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state # (B, T, 1024) 151 | ssl = pool_model(ssl, mask) 152 | 153 | emo_pred = ser_model(ssl) 154 | ccc = utils.CCC_loss(emo_pred, y) 155 | loss = 1.0 - ccc 156 | total_loss = torch.sum(loss) / ACCUMULATION_STEP 157 | 
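# Mixed-precision bookkeeping: the forward pass above ran under autocast, so
# the loss is scaled before backward() to keep small fp16 gradients from
# underflowing; scaler.step() unscales gradients before each optimizer update
# and scaler.update() re-tunes the scale factor afterwards.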
scaler.scale(total_loss).backward() 158 | if (batch_cnt+1) % ACCUMULATION_STEP == 0 or (batch_cnt+1) == len(total_dataloader["train"]): 159 | scaler.step(ssl_opt) 160 | scaler.step(ser_opt) 161 | if is_attentive_pooling: 162 | scaler.step(pool_opt) 163 | scaler.update() 164 | ssl_opt.zero_grad(set_to_none=True) 165 | ser_opt.zero_grad(set_to_none=True) 166 | if is_attentive_pooling: 167 | pool_opt.zero_grad(set_to_none=True) 168 | batch_cnt += 1 169 | 170 | # Logging 171 | lm.add_torch_stat("train_aro", ccc[0]) 172 | lm.add_torch_stat("train_dom", ccc[1]) 173 | lm.add_torch_stat("train_val", ccc[2]) 174 | 175 | ssl_model.eval() 176 | pool_model.eval() 177 | ser_model.eval() 178 | total_pred = [] 179 | total_y = [] 180 | for xy_pair in tqdm(total_dataloader["dev"]): 181 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 182 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 183 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 184 | 185 | with torch.no_grad(): 186 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 187 | ssl = pool_model(ssl, mask) 188 | emo_pred = ser_model(ssl) 189 | 190 | total_pred.append(emo_pred) 191 | total_y.append(y) 192 | 193 | # CCC calculation 194 | total_pred = torch.cat(total_pred, 0) 195 | total_y = torch.cat(total_y, 0) 196 | ccc = utils.CCC_loss(total_pred, total_y) 197 | # Logging 198 | lm.add_torch_stat("dev_aro", ccc[0]) 199 | lm.add_torch_stat("dev_dom", ccc[1]) 200 | lm.add_torch_stat("dev_val", ccc[2]) 201 | 202 | # Save model 203 | lm.print_stat() 204 | 205 | dev_loss = 3.0 - lm.get_stat("dev_aro") - lm.get_stat("dev_dom") - lm.get_stat("dev_val") 206 | if min_loss > dev_loss: 207 | min_epoch = epoch 208 | min_loss = dev_loss 209 | 210 | print("Save",min_epoch) 211 | print("Loss",3.0-min_loss) 212 | save_model_list = ["ser", "ssl"] 213 | if is_attentive_pooling: 214 | save_model_list.append("pool") 215 | 216 | torch.save(ser_model.state_dict(), \ 217 | os.path.join(MODEL_PATH, "final_ser.pt")) 218 | torch.save(ssl_model.state_dict(), \ 219 | os.path.join(MODEL_PATH, "final_ssl.pt")) 220 | if is_attentive_pooling: 221 | torch.save(pool_model.state_dict(), \ 222 | os.path.join(MODEL_PATH, "final_pool.pt")) -------------------------------------------------------------------------------- /train_ft_dom_dim_ser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | import numpy as np 8 | import pickle as pk 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import glob 12 | import librosa 13 | import copy 14 | 15 | # PyTorch Modules 16 | import torch 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | import torch.nn.functional as F 20 | from torch.utils.data import ConcatDataset, DataLoader 21 | import torch.optim as optim 22 | from transformers import AutoModel 23 | import importlib 24 | # Self-Written Modules 25 | sys.path.append(os.getcwd()) 26 | import net 27 | import utils 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--seed", type=int, default=100) 32 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 33 | parser.add_argument("--batch_size", type=int, default=32) 34 | parser.add_argument("--accumulation_steps", type=int, default=1) 35 | parser.add_argument("--epochs", type=int, default=10) 36 | parser.add_argument("--lr", type=float, default=0.001) 37 | parser.add_argument("--model_path", type=str, 
default="./temp") 38 | parser.add_argument("--head_dim", type=int, default=1024) 39 | 40 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 41 | args = parser.parse_args() 42 | 43 | utils.set_deterministic(args.seed) 44 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 45 | assert SSL_TYPE != None, print("Invalid SSL type!") 46 | BATCH_SIZE = args.batch_size 47 | ACCUMULATION_STEP = args.accumulation_steps 48 | assert (ACCUMULATION_STEP > 0) and (BATCH_SIZE % ACCUMULATION_STEP == 0) 49 | EPOCHS=args.epochs 50 | LR=args.lr 51 | MODEL_PATH = args.model_path 52 | 53 | 54 | import json 55 | from collections import defaultdict 56 | config_path = "config.json" 57 | with open(config_path, "r") as f: 58 | config = json.load(f) 59 | audio_path = config["wav_dir"] 60 | label_path = config["label_path"] 61 | 62 | total_dataset=dict() 63 | total_dataloader=dict() 64 | for dtype in ["train", "dev"]: 65 | cur_utts, cur_labs = utils.load_adv_dominance(label_path, dtype) 66 | cur_wavs = utils.load_audio(audio_path, cur_utts) 67 | if dtype == "train": 68 | cur_wav_set = utils.WavSet(cur_wavs) 69 | cur_wav_set.save_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 70 | else: 71 | if dtype == "dev": 72 | wav_mean = total_dataset["train"].datasets[0].wav_mean 73 | wav_std = total_dataset["train"].datasets[0].wav_std 74 | elif dtype == "test": 75 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 76 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 77 | ######################################################## 78 | cur_bs = BATCH_SIZE // ACCUMULATION_STEP if dtype == "train" else 1 79 | is_shuffle=True if dtype == "train" else False 80 | ######################################################## 81 | cur_emo_set = utils.ADV_EmoSet(cur_labs) 82 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 83 | total_dataloader[dtype] = DataLoader( 84 | total_dataset[dtype], batch_size=cur_bs, shuffle=is_shuffle, 85 | pin_memory=True, num_workers=4, 86 | collate_fn=utils.collate_fn_wav_lab_mask 87 | ) 88 | 89 | print("Loading pre-trained ", SSL_TYPE, " model...") 90 | 91 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 92 | ssl_model.freeze_feature_encoder() 93 | ssl_model.eval(); ssl_model.cuda() 94 | 95 | ########## Implement pooling method ########## 96 | feat_dim = ssl_model.config.hidden_size 97 | 98 | pool_net = getattr(net, args.pooling_type) 99 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 100 | if args.pooling_type in attention_pool_type_list: 101 | is_attentive_pooling = True 102 | pool_model = pool_net(feat_dim) 103 | else: 104 | is_attentive_pooling = False 105 | pool_model = pool_net() 106 | print(pool_model) 107 | pool_model.cuda() 108 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 109 | dh_input_dim = feat_dim * 2 \ 110 | if args.pooling_type in concat_pool_type_list \ 111 | else feat_dim 112 | 113 | ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 1, dropout=0.5) 114 | ############################################## 115 | ser_model.eval(); ser_model.cuda() 116 | 117 | ssl_opt = torch.optim.AdamW(ssl_model.parameters(), LR) 118 | ser_opt = torch.optim.AdamW(ser_model.parameters(), LR) 119 | 120 | ssl_opt.zero_grad(set_to_none=True) 121 | ser_opt.zero_grad(set_to_none=True) 122 | 123 | if is_attentive_pooling: 124 | pool_opt = torch.optim.AdamW(pool_model.parameters(), LR) 125 | pool_opt.zero_grad(set_to_none=True) 126 | 127 | lm = utils.LogManager() 128 | 
lm.alloc_stat_type_list(["train_dom"]) 129 | lm.alloc_stat_type_list(["dev_dom"]) 130 | 131 | min_epoch=0 132 | min_loss=1e10 133 | 134 | for epoch in range(EPOCHS): 135 | print("Epoch: ", epoch) 136 | lm.init_stat() 137 | ssl_model.train() 138 | pool_model.train() 139 | ser_model.train() 140 | batch_cnt = 0 141 | 142 | for xy_pair in tqdm(total_dataloader["train"]): 143 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 144 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 145 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 146 | 147 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state # (B, T, 1024) 148 | ssl = pool_model(ssl, mask) 149 | 150 | emo_pred = ser_model(ssl) 151 | 152 | ccc = utils.CCC_loss(emo_pred, y) 153 | loss = 1.0 - ccc 154 | total_loss = torch.sum(loss) / ACCUMULATION_STEP 155 | total_loss.backward() 156 | if (batch_cnt+1) % ACCUMULATION_STEP == 0 or (batch_cnt+1) == len(total_dataloader["train"]): 157 | 158 | ssl_opt.step() 159 | ser_opt.step() 160 | 161 | if is_attentive_pooling: 162 | 163 | pool_opt.step() 164 | 165 | ssl_opt.zero_grad(set_to_none=True) 166 | ser_opt.zero_grad(set_to_none=True) 167 | if is_attentive_pooling: 168 | pool_opt.zero_grad(set_to_none=True) 169 | batch_cnt += 1 170 | 171 | # Logging 172 | lm.add_torch_stat("train_dom", ccc[0]) 173 | 174 | 175 | ssl_model.eval() 176 | pool_model.eval() 177 | ser_model.eval() 178 | total_pred = [] 179 | total_y = [] 180 | for xy_pair in tqdm(total_dataloader["dev"]): 181 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 182 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 183 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 184 | 185 | with torch.no_grad(): 186 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 187 | ssl = pool_model(ssl, mask) 188 | emo_pred = ser_model(ssl) 189 | 190 | total_pred.append(emo_pred) 191 | total_y.append(y) 192 | 193 | # CCC calculation 194 | total_pred = torch.cat(total_pred, 0) 195 | total_y = torch.cat(total_y, 0) 196 | ccc = utils.CCC_loss(total_pred, total_y) 197 | # Logging 198 | lm.add_torch_stat("dev_dom", ccc[0]) 199 | 200 | 201 | # Save model 202 | lm.print_stat() 203 | 204 | 205 | dev_loss = 1.0 - lm.get_stat("dev_dom") 206 | if min_loss > dev_loss: 207 | min_epoch = epoch 208 | min_loss = dev_loss 209 | 210 | print("Save",min_epoch) 211 | print("Loss",3.0-min_loss) 212 | save_model_list = ["ser", "ssl"] 213 | if is_attentive_pooling: 214 | save_model_list.append("pool") 215 | 216 | torch.save(ser_model.state_dict(), \ 217 | os.path.join(MODEL_PATH, "final_ser.pt")) 218 | torch.save(ssl_model.state_dict(), \ 219 | os.path.join(MODEL_PATH, "final_ssl.pt")) 220 | if is_attentive_pooling: 221 | torch.save(pool_model.state_dict(), \ 222 | os.path.join(MODEL_PATH, "final_pool.pt")) -------------------------------------------------------------------------------- /train_ft_val_dim_ser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # Local modules 3 | import os 4 | import sys 5 | import argparse 6 | # 3rd-Party Modules 7 | import numpy as np 8 | import pickle as pk 9 | import pandas as pd 10 | from tqdm import tqdm 11 | import glob 12 | import librosa 13 | import copy 14 | 15 | # PyTorch Modules 16 | import torch 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | import torch.nn.functional as F 20 | from torch.utils.data import ConcatDataset, DataLoader 21 | import torch.optim as optim 22 | from transformers import AutoModel 23 | 
import importlib 24 | 25 | # Self-Written Modules 26 | sys.path.append(os.getcwd()) 27 | import net 28 | import utils 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--seed", type=int, default=100) 33 | parser.add_argument("--ssl_type", type=str, default="wavlm-large") 34 | parser.add_argument("--batch_size", type=int, default=32) 35 | parser.add_argument("--accumulation_steps", type=int, default=1) 36 | parser.add_argument("--epochs", type=int, default=10) 37 | parser.add_argument("--lr", type=float, default=0.001) 38 | parser.add_argument("--model_path", type=str, default="./temp") 39 | parser.add_argument("--head_dim", type=int, default=1024) 40 | 41 | parser.add_argument("--pooling_type", type=str, default="MeanPooling") 42 | args = parser.parse_args() 43 | 44 | utils.set_deterministic(args.seed) 45 | SSL_TYPE = utils.get_ssl_type(args.ssl_type) 46 | assert SSL_TYPE != None, print("Invalid SSL type!") 47 | BATCH_SIZE = args.batch_size 48 | ACCUMULATION_STEP = args.accumulation_steps 49 | assert (ACCUMULATION_STEP > 0) and (BATCH_SIZE % ACCUMULATION_STEP == 0) 50 | EPOCHS=args.epochs 51 | LR=args.lr 52 | MODEL_PATH = args.model_path 53 | 54 | 55 | import json 56 | from collections import defaultdict 57 | config_path = "config.json" 58 | with open(config_path, "r") as f: 59 | config = json.load(f) 60 | audio_path = config["wav_dir"] 61 | label_path = config["label_path"] 62 | 63 | total_dataset=dict() 64 | total_dataloader=dict() 65 | for dtype in ["train", "dev"]: 66 | cur_utts, cur_labs = utils.load_adv_valence(label_path, dtype) 67 | cur_wavs = utils.load_audio(audio_path, cur_utts) 68 | if dtype == "train": 69 | cur_wav_set = utils.WavSet(cur_wavs) 70 | cur_wav_set.save_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 71 | else: 72 | if dtype == "dev": 73 | wav_mean = total_dataset["train"].datasets[0].wav_mean 74 | wav_std = total_dataset["train"].datasets[0].wav_std 75 | elif dtype == "test": 76 | wav_mean, wav_std = utils.load_norm_stat(MODEL_PATH+"/train_norm_stat.pkl") 77 | cur_wav_set = utils.WavSet(cur_wavs, wav_mean=wav_mean, wav_std=wav_std) 78 | ######################################################## 79 | cur_bs = BATCH_SIZE // ACCUMULATION_STEP if dtype == "train" else 1 80 | is_shuffle=True if dtype == "train" else False 81 | ######################################################## 82 | cur_emo_set = utils.ADV_EmoSet(cur_labs) 83 | total_dataset[dtype] = utils.CombinedSet([cur_wav_set, cur_emo_set, cur_utts]) 84 | total_dataloader[dtype] = DataLoader( 85 | total_dataset[dtype], batch_size=cur_bs, shuffle=is_shuffle, 86 | pin_memory=True, num_workers=4, 87 | collate_fn=utils.collate_fn_wav_lab_mask 88 | ) 89 | 90 | print("Loading pre-trained ", SSL_TYPE, " model...") 91 | 92 | ssl_model = AutoModel.from_pretrained(SSL_TYPE) 93 | ssl_model.freeze_feature_encoder() 94 | ssl_model.eval(); ssl_model.cuda() 95 | 96 | ########## Implement pooling method ########## 97 | feat_dim = ssl_model.config.hidden_size 98 | 99 | pool_net = getattr(net, args.pooling_type) 100 | attention_pool_type_list = ["AttentiveStatisticsPooling"] 101 | if args.pooling_type in attention_pool_type_list: 102 | is_attentive_pooling = True 103 | pool_model = pool_net(feat_dim) 104 | else: 105 | is_attentive_pooling = False 106 | pool_model = pool_net() 107 | print(pool_model) 108 | pool_model.cuda() 109 | concat_pool_type_list = ["AttentiveStatisticsPooling"] 110 | dh_input_dim = feat_dim * 2 \ 111 | if args.pooling_type in concat_pool_type_list \ 112 | else feat_dim 113 | 114 | 
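# AttentiveStatisticsPooling doubles the feature dimension (the usual
# attentive-statistics formulation concatenates a weighted mean and a weighted
# standard deviation), so the regression head input becomes 2 * feat_dim
# (2048 for wavlm-large); MeanPooling keeps feat_dim (1024).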
ser_model = net.EmotionRegression(dh_input_dim, args.head_dim, 1, 1, dropout=0.5) 115 | ############################################## 116 | ser_model.eval(); ser_model.cuda() 117 | 118 | ssl_opt = torch.optim.AdamW(ssl_model.parameters(), LR) 119 | ser_opt = torch.optim.AdamW(ser_model.parameters(), LR) 120 | 121 | ssl_opt.zero_grad(set_to_none=True) 122 | ser_opt.zero_grad(set_to_none=True) 123 | 124 | if is_attentive_pooling: 125 | pool_opt = torch.optim.AdamW(pool_model.parameters(), LR) 126 | pool_opt.zero_grad(set_to_none=True) 127 | 128 | lm = utils.LogManager() 129 | lm.alloc_stat_type_list(["train_val"]) 130 | lm.alloc_stat_type_list(["dev_val"]) 131 | 132 | min_epoch=0 133 | min_loss=1e10 134 | 135 | for epoch in range(EPOCHS): 136 | print("Epoch: ", epoch) 137 | lm.init_stat() 138 | ssl_model.train() 139 | pool_model.train() 140 | ser_model.train() 141 | batch_cnt = 0 142 | 143 | for xy_pair in tqdm(total_dataloader["train"]): 144 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 145 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 146 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 147 | 148 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state # (B, T, 1024) 149 | ssl = pool_model(ssl, mask) 150 | 151 | emo_pred = ser_model(ssl) 152 | 153 | ccc = utils.CCC_loss(emo_pred, y) 154 | loss = 1.0 - ccc 155 | total_loss = torch.sum(loss) / ACCUMULATION_STEP 156 | total_loss.backward() 157 | if (batch_cnt+1) % ACCUMULATION_STEP == 0 or (batch_cnt+1) == len(total_dataloader["train"]): 158 | ssl_opt.step() 159 | ser_opt.step() 160 | 161 | if is_attentive_pooling: 162 | pool_opt.step() 163 | 164 | ssl_opt.zero_grad(set_to_none=True) 165 | ser_opt.zero_grad(set_to_none=True) 166 | if is_attentive_pooling: 167 | pool_opt.zero_grad(set_to_none=True) 168 | batch_cnt += 1 169 | 170 | # Logging 171 | lm.add_torch_stat("train_val", ccc[0]) 172 | 173 | ssl_model.eval() 174 | pool_model.eval() 175 | ser_model.eval() 176 | total_pred = [] 177 | total_y = [] 178 | for xy_pair in tqdm(total_dataloader["dev"]): 179 | x = xy_pair[0]; x=x.cuda(non_blocking=True).float() 180 | y = xy_pair[1]; y=y.cuda(non_blocking=True).float() 181 | mask = xy_pair[2]; mask=mask.cuda(non_blocking=True).float() 182 | 183 | with torch.no_grad(): 184 | ssl = ssl_model(x, attention_mask=mask).last_hidden_state 185 | ssl = pool_model(ssl, mask) 186 | emo_pred = ser_model(ssl) 187 | 188 | total_pred.append(emo_pred) 189 | total_y.append(y) 190 | 191 | # CCC calculation 192 | total_pred = torch.cat(total_pred, 0) 193 | total_y = torch.cat(total_y, 0) 194 | ccc = utils.CCC_loss(total_pred, total_y) 195 | # Logging 196 | lm.add_torch_stat("dev_val", ccc[0]) 197 | 198 | # Save model 199 | lm.print_stat() 200 | 201 | 202 | dev_loss = 1.0 - lm.get_stat("dev_val") 203 | if min_loss > dev_loss: 204 | min_epoch = epoch 205 | min_loss = dev_loss 206 | 207 | print("Save",min_epoch) 208 | print("Loss",3.0-min_loss) 209 | save_model_list = ["ser", "ssl"] 210 | if is_attentive_pooling: 211 | save_model_list.append("pool") 212 | 213 | torch.save(ser_model.state_dict(), \ 214 | os.path.join(MODEL_PATH, "final_ser.pt")) 215 | torch.save(ssl_model.state_dict(), \ 216 | os.path.join(MODEL_PATH, "final_ssl.pt")) 217 | if is_attentive_pooling: 218 | torch.save(pool_model.state_dict(), \ 219 | os.path.join(MODEL_PATH, "final_pool.pt")) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .data import *
2 | from .dataset import *
3 | from .loss_manager import *
4 | from .etc import *
5 | 
--------------------------------------------------------------------------------
/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .wav import *
2 | from .podcast import *
3 | 
--------------------------------------------------------------------------------
/utils/data/podcast.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | SPLIT_MAP = {
5 | "train": "Train",
6 | "dev": "Development",
7 | "test1": "Test1",
8 | "test2": "Test2",
9 | "test3": "Test3"
10 | }
11 | 
12 | # Load label
13 | def load_utts(label_path, dtype):
14 | label_df = pd.read_csv(label_path, sep=",")
15 | cur_df = label_df[label_df["Split_Set"] == SPLIT_MAP[dtype]]
16 | cur_utts = cur_df["FileName"].to_numpy()
17 | 
18 | return cur_utts
19 | 
20 | def load_adv_emo_label(label_path, dtype):
21 | label_df = pd.read_csv(label_path, sep=",")
22 | cur_df = label_df[label_df["Split_Set"] == SPLIT_MAP[dtype]]
23 | cur_utts = cur_df["FileName"].to_numpy()
24 | cur_labs = cur_df[["EmoAct", "EmoDom", "EmoVal"]].to_numpy()
25 | 
26 | return cur_utts, cur_labs
27 | 
28 | def load_adv_arousal(label_path, dtype):
29 | label_df = pd.read_csv(label_path, sep=",")
30 | cur_df = label_df[label_df["Split_Set"] == SPLIT_MAP[dtype]]
31 | cur_utts = cur_df["FileName"].to_numpy()
32 | cur_labs = cur_df[["EmoAct"]].to_numpy()
33 | 
34 | return cur_utts, cur_labs
35 | 
36 | def load_adv_valence(label_path, dtype):
37 | label_df = pd.read_csv(label_path, sep=",")
38 | cur_df = label_df[label_df["Split_Set"] == SPLIT_MAP[dtype]]
39 | cur_utts = cur_df["FileName"].to_numpy()
40 | cur_labs = cur_df[["EmoVal"]].to_numpy()
41 | 
42 | return cur_utts, cur_labs
43 | 
44 | def load_adv_dominance(label_path, dtype):
45 | label_df = pd.read_csv(label_path, sep=",")
46 | cur_df = label_df[label_df["Split_Set"] == SPLIT_MAP[dtype]]
47 | cur_utts = cur_df["FileName"].to_numpy()
48 | cur_labs = cur_df[["EmoDom"]].to_numpy()
49 | 
50 | return cur_utts, cur_labs
51 | 
52 | def load_cat_emo_label(label_path, dtype):
53 | label_df = pd.read_csv(label_path, sep=",")
54 | cur_df = label_df[label_df["Split_Set"] == SPLIT_MAP[dtype]]
55 | cur_utts = cur_df["FileName"].to_numpy()
56 | cur_labs = cur_df[["Angry", "Sad", "Happy", "Surprise", "Fear", "Disgust", "Contempt", "Neutral"]].to_numpy()
57 | 
58 | return cur_utts, cur_labs
59 | 
60 | def load_spk_id(label_path, dtype):
61 | label_df = pd.read_csv(label_path, sep=",")
62 | cur_df = label_df[(label_df["Split_Set"] == SPLIT_MAP[dtype])]
63 | cur_df = cur_df[(cur_df["SpkrID"] != "Unknown")]
64 | cur_utts = cur_df["FileName"].to_numpy()
65 | cur_spk_ids = cur_df["SpkrID"].to_numpy().astype(int)
66 | # Cleaning speaker id
67 | uniq_spk_id = list(set(cur_spk_ids))
68 | uniq_spk_id.sort()
69 | for new_id, old_id in enumerate(uniq_spk_id):
70 | cur_spk_ids[cur_spk_ids == old_id] = new_id
71 | total_spk_num = len(uniq_spk_id)
72 | 
73 | return cur_utts, cur_spk_ids, total_spk_num
--------------------------------------------------------------------------------
/utils/data/wav.py:
--------------------------------------------------------------------------------
1 | import os
2 | import librosa
3 | from tqdm import tqdm
4 | from multiprocessing import Pool
5 | 
6 | # Load audio
7 | def extract_wav(wav_path):
8 | raw_wav, _ = librosa.load(wav_path, sr=16000)
9 | return 
raw_wav 10 | def load_audio(audio_path, utts, nj=24): 11 | # Audio path: directory of audio files 12 | # utts: list of utterance names with .wav extension 13 | wav_paths = [os.path.join(audio_path, utt) for utt in utts] 14 | with Pool(nj) as p: 15 | wavs = list(tqdm(p.imap(extract_wav, wav_paths), total=len(wav_paths))) 16 | return wavs -------------------------------------------------------------------------------- /utils/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .collate_fn import * 3 | from .normalizer import * -------------------------------------------------------------------------------- /utils/dataset/collate_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def collate_fn_wav_lab_mask(batch): 6 | total_wav = [] 7 | total_lab = [] 8 | total_dur = [] 9 | total_utt = [] 10 | for wav_data in batch: 11 | 12 | wav, dur = wav_data[0] 13 | lab = wav_data[1] 14 | total_wav.append(torch.Tensor(wav)) 15 | total_lab.append(lab) 16 | total_dur.append(dur) 17 | total_utt.append(wav_data[2]) 18 | 19 | total_wav = nn.utils.rnn.pad_sequence(total_wav, batch_first=True) 20 | 21 | total_lab = torch.Tensor(np.array(total_lab)) 22 | max_dur = np.max(total_dur) 23 | attention_mask = torch.zeros(total_wav.shape[0], max_dur) 24 | for data_idx, dur in enumerate(total_dur): 25 | attention_mask[data_idx,:dur] = 1 26 | ## compute mask 27 | return total_wav, total_lab, attention_mask, total_utt 28 | 29 | 30 | def collate_fn_wav_test3(batch): 31 | total_wav = [] 32 | total_dur = [] 33 | total_utt = [] 34 | for wav_data in batch: 35 | 36 | wav, dur = wav_data[0] 37 | total_wav.append(torch.Tensor(wav)) 38 | total_dur.append(dur) 39 | total_utt.append(wav_data[1]) 40 | 41 | total_wav = nn.utils.rnn.pad_sequence(total_wav, batch_first=True) 42 | 43 | max_dur = np.max(total_dur) 44 | attention_mask = torch.zeros(total_wav.shape[0], max_dur) 45 | for data_idx, dur in enumerate(total_dur): 46 | attention_mask[data_idx,:dur] = 1 47 | ## compute mask 48 | return total_wav, attention_mask, total_utt 49 | -------------------------------------------------------------------------------- /utils/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle as pk 3 | import torch.utils as torch_utils 4 | from . 
import normalizer 5 | 6 | """ 7 | All dataset should have the same order based on the utt_list 8 | """ 9 | def load_norm_stat(norm_stat_file): 10 | with open(norm_stat_file, 'rb') as f: 11 | wav_mean, wav_std = pk.load(f) 12 | return wav_mean, wav_std 13 | 14 | 15 | class CombinedSet(torch_utils.data.Dataset): 16 | def __init__(self, *args, **kwargs): 17 | super(CombinedSet, self).__init__() 18 | self.datasets = kwargs.get("datasets", args[0]) 19 | self.data_len = len(self.datasets[0]) 20 | for cur_dataset in self.datasets: 21 | assert len(cur_dataset) == self.data_len, "All dataset should have the same order based on the utt_list" 22 | def __len__(self): 23 | return self.data_len 24 | 25 | def __getitem__(self, idx): 26 | result = [] 27 | for cur_dataset in self.datasets: 28 | result.append(cur_dataset[idx]) 29 | return result 30 | 31 | 32 | class WavSet(torch_utils.data.Dataset): 33 | def __init__(self, *args, **kwargs): 34 | super(WavSet, self).__init__() 35 | self.wav_list = kwargs.get("wav_list", args[0]) # (N, D, T) 36 | 37 | self.wav_mean = kwargs.get("wav_mean", None) 38 | self.wav_std = kwargs.get("wav_std", None) 39 | 40 | self.upper_bound_max_dur = kwargs.get("max_dur", 12) 41 | self.sampling_rate = kwargs.get("sr", 16000) 42 | 43 | # check max duration 44 | self.max_dur = np.min([np.max([len(cur_wav) for cur_wav in self.wav_list]), self.upper_bound_max_dur*self.sampling_rate]) 45 | if self.wav_mean is None or self.wav_std is None: 46 | self.wav_mean, self.wav_std = normalizer. get_norm_stat_for_wav(self.wav_list) 47 | 48 | def save_norm_stat(self, norm_stat_file): 49 | with open(norm_stat_file, 'wb') as f: 50 | pk.dump((self.wav_mean, self.wav_std), f) 51 | 52 | def __len__(self): 53 | return len(self.wav_list) 54 | 55 | def __getitem__(self, idx): 56 | cur_wav = self.wav_list[idx][:self.max_dur] 57 | cur_dur = len(cur_wav) 58 | cur_wav = (cur_wav - self.wav_mean) / (self.wav_std+0.000001) 59 | 60 | result = (cur_wav, cur_dur) 61 | return result 62 | 63 | class ADV_EmoSet(torch_utils.data.Dataset): 64 | def __init__(self, *args, **kwargs): 65 | super(ADV_EmoSet, self).__init__() 66 | self.lab_list = kwargs.get("lab_list", args[0]) 67 | self.max_score = kwargs.get("max_score", 7) 68 | self.min_score = kwargs.get("min_score", 1) 69 | 70 | def __len__(self): 71 | return len(self.lab_list) 72 | 73 | def __getitem__(self, idx): 74 | cur_lab = self.lab_list[idx] 75 | cur_lab = (cur_lab - self.min_score) / (self.max_score-self.min_score) 76 | result = cur_lab 77 | return result 78 | 79 | class CAT_EmoSet(torch_utils.data.Dataset): 80 | def __init__(self, *args, **kwargs): 81 | super(CAT_EmoSet, self).__init__() 82 | self.lab_list = kwargs.get("lab_list", args[0]) 83 | 84 | def __len__(self): 85 | return len(self.lab_list) 86 | 87 | def __getitem__(self, idx): 88 | cur_lab = self.lab_list[idx] 89 | result = cur_lab 90 | return result 91 | 92 | class SpkSet(torch_utils.data.Dataset): 93 | def __init__(self, *args, **kwargs): 94 | super(SpkSet, self).__init__() 95 | self.spk_list = kwargs.get("spk_list", args[0]) 96 | 97 | def __len__(self): 98 | return len(self.spk_list) 99 | 100 | def __getitem__(self, idx): 101 | cur_lab = self.spk_list[idx] 102 | result = cur_lab 103 | return result 104 | 105 | class UttSet(torch_utils.data.Dataset): 106 | def __init__(self, *args, **kwargs): 107 | super(UttSet, self).__init__() 108 | self.utt_list = kwargs.get("utt_list", args[0]) 109 | 110 | def __len__(self): 111 | return len(self.utt_list) 112 | 113 | def __getitem__(self, idx): 114 | cur_lab = 
self.utt_list[idx] 115 | result = cur_lab 116 | return result 117 | -------------------------------------------------------------------------------- /utils/dataset/normalizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def get_norm_stat_for_wav(wav_list, verbose=False): 3 | count = 0 4 | wav_sum = 0 5 | wav_sqsum = 0 6 | 7 | for cur_wav in wav_list: 8 | wav_sum += np.sum(cur_wav) 9 | wav_sqsum += np.sum(cur_wav**2) 10 | count += len(cur_wav) 11 | 12 | wav_mean = wav_sum / count 13 | wav_var = (wav_sqsum / count) - (wav_mean**2) 14 | wav_std = np.sqrt(wav_var) 15 | 16 | return wav_mean, wav_std -------------------------------------------------------------------------------- /utils/etc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import numpy as np 5 | import json 6 | def set_deterministic(seed): 7 | torch.manual_seed(seed) 8 | np.random.seed(seed) 9 | 10 | def get_ssl_type(ssl_type): 11 | ssl_book={ 12 | "wav2vec2-large-robust": "facebook/wav2vec2-large-robust", 13 | "wavlm-large": "microsoft/wavlm-large" 14 | } 15 | return ssl_book.get(ssl_type, None) -------------------------------------------------------------------------------- /utils/loss_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import sys 6 | import torch.autograd as autograd 7 | from collections import defaultdict 8 | 9 | class LogManager: 10 | def __init__(self): 11 | self.log_book=defaultdict(lambda: []) 12 | def alloc_stat_type(self, stat_type): 13 | self.log_book[stat_type] = [] 14 | def alloc_stat_type_list(self, stat_type_list): 15 | for stat_type in stat_type_list: 16 | self.alloc_stat_type(stat_type) 17 | def init_stat(self): 18 | for stat_type in self.log_book.keys(): 19 | self.log_book[stat_type] = [] 20 | def add_stat(self, stat_type, stat): 21 | assert stat_type in self.log_book, "Wrong stat type" 22 | self.log_book[stat_type].append(stat) 23 | def add_torch_stat(self, stat_type, stat): 24 | assert stat_type in self.log_book, "Wrong stat type" 25 | self.log_book[stat_type].append(stat.detach().cpu().item()) 26 | def get_stat(self, stat_type): 27 | result_stat = 0 28 | stat_list = self.log_book[stat_type] 29 | if len(stat_list) != 0: 30 | result_stat = np.mean(stat_list) 31 | result_stat = np.round(result_stat, 4) 32 | return result_stat 33 | 34 | def print_stat(self): 35 | for stat_type in self.log_book.keys(): 36 | if len(self.log_book[stat_type]) == 0: 37 | continue 38 | stat = self.get_stat(stat_type) 39 | print(stat_type,":",stat, end=' / ') 40 | print(" ") 41 | 42 | def get_stat_str(self): 43 | result_str = "" 44 | for stat_type in self.log_book.keys(): 45 | if len(self.log_book[stat_type]) == 0: 46 | continue 47 | stat = self.get_stat(stat_type) 48 | result_str += str(stat) + " / " 49 | return result_str 50 | 51 | def CCC_loss(pred, lab, m_lab=None, v_lab=None, is_numpy=False): 52 | """ 53 | pred: (N, 3) 54 | lab: (N, 3) 55 | """ 56 | if is_numpy: 57 | pred = torch.Tensor(pred).float().cuda() 58 | lab = torch.Tensor(lab).float().cuda() 59 | 60 | m_pred = torch.mean(pred, 0, keepdim=True) 61 | m_lab = torch.mean(lab, 0, keepdim=True) 62 | 63 | d_pred = pred - m_pred 64 | d_lab = lab - m_lab 65 | 66 | v_pred = torch.var(pred, 0, unbiased=False) 67 | v_lab = torch.var(lab, 0, unbiased=False) 68 | 69 | corr = 
torch.sum(d_pred * d_lab, 0) / (torch.sqrt(torch.sum(d_pred ** 2, 0)) * torch.sqrt(torch.sum(d_lab ** 2, 0)))
70 | 
71 | s_pred = torch.std(pred, 0, unbiased=False)
72 | s_lab = torch.std(lab, 0, unbiased=False)
73 | 
74 | ccc = (2*corr*s_pred*s_lab) / (v_pred + v_lab + (m_pred[0]-m_lab[0])**2)
75 | return ccc
76 | 
77 | def MSE_emotion(pred, lab):
78 | aro_loss = F.mse_loss(pred[:, 0], lab[:, 0])
79 | dom_loss = F.mse_loss(pred[:, 1], lab[:, 1])
80 | val_loss = F.mse_loss(pred[:, 2], lab[:, 2])
81 | 
82 | return [aro_loss, dom_loss, val_loss]
83 | 
84 | 
85 | def CE_weight_category(pred, lab, weights):
86 | criterion = torch.nn.CrossEntropyLoss(weight=weights)
87 | return criterion(pred, lab)
88 | 
89 | 
90 | def calc_err(pred, lab):
91 | p = pred.detach()
92 | t = lab.detach()
93 | total_num = p.size()[0]
94 | ans = torch.argmax(p, dim=1)
95 | corr = torch.sum((ans==t).long())
96 | 
97 | err = (total_num-corr) / total_num
98 | 
99 | return err
100 | 
101 | def calc_acc(pred, lab):
102 | err = calc_err(pred, lab)
103 | return 1.0 - err
104 | 
--------------------------------------------------------------------------------
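As a quick cross-check of `CCC_loss`, the concordance correlation coefficient can be recomputed independently in NumPy. The sketch below is illustrative, not part of the repository: it assumes it is run from the repository root (so `import utils` resolves) and uses random arrays in place of real predictions and labels.

```python
import os
import sys
import numpy as np
import torch

sys.path.append(os.getcwd())
import utils  # assumes the current directory is the repository root

def ccc_numpy(pred, lab):
    # CCC = 2*cov(x, y) / (var(x) + var(y) + (mean(x) - mean(y))^2),
    # computed per column with population statistics (ddof=0), matching
    # unbiased=False in CCC_loss.
    m_p, m_l = pred.mean(0), lab.mean(0)
    v_p, v_l = pred.var(0), lab.var(0)
    cov = ((pred - m_p) * (lab - m_l)).mean(0)
    return (2.0 * cov) / (v_p + v_l + (m_p - m_l) ** 2)

# Hypothetical scores: 8 utterances x 3 attributes (arousal, dominance, valence)
pred = np.random.rand(8, 3).astype(np.float32)
lab = np.random.rand(8, 3).astype(np.float32)

# CCC_loss returns the CCC itself; the training scripts minimize 1 - CCC.
torch_ccc = utils.CCC_loss(torch.from_numpy(pred), torch.from_numpy(lab))
np.testing.assert_allclose(torch_ccc.numpy(), ccc_numpy(pred, lab), rtol=1e-4)
print("CCC per attribute:", ccc_numpy(pred, lab))
```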