├── LICENSE
├── README.md
├── Data_Preprocessing
│   ├── raw_data_processing.py
│   └── load_files.py
└── EmoMA-Net.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 XJTLUSURF20240123

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# EmoMA-Net
## Our Paper
You can access our paper through: [https://dl.acm.org/doi/10.1145/3704289.3704303](https://dl.acm.org/doi/10.1145/3704289.3704303)

## EmoMA-Net Model Training and Evaluation
### Overview
This project aims to train and evaluate a deep learning model for emotion recognition based on the WESAD dataset [1]. The model architecture combines Convolutional Neural Networks (CNN) with Convolutional Block Attention Module (CBAM) and Long Short-Term Memory (LSTM) networks to process time-series data effectively [2, 3, 4].
### Dataset
The dataset used for training and evaluation is stored in a CSV file named `merged.csv`. It contains time-series data along with labels representing different emotions. Before training, the dataset is preprocessed to ensure compatibility with the model.
### Data Preparation
1. Load and Preprocess Data:
   - Load the dataset from the CSV file.
   - Change the labels to binary values (0 or 1).
   - Select relevant features for training.
   - Scale the features using `StandardScaler`.
   - Perform Recursive Feature Elimination (RFE) to select the top features.
2. Split Data:
   - Split the data into training and testing sets.
   - Use KFold for cross-validation.
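These steps are implemented in `EmoMA-Net.py`; the short sketch below condenses them for orientation. It is a minimal illustration rather than the authoritative pipeline: it assumes `merged.csv` sits in a `data/` folder as in the script, and it simply drops the `subject` and `label` columns instead of using the script's explicit feature list.

```python
# Condensed sketch of the data-preparation steps (see EmoMA-Net.py for the full version).
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFE
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

df = pd.read_csv('data/merged.csv', index_col=0)
df['label'] = df['label'].apply(lambda l: 0 if l in (0, 1) else 1)  # binarize labels

X, y = df.drop(columns=['label', 'subject']), df['label']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # fit the scaler on the training split only
X_test_scaled = scaler.transform(X_test)

# Recursive Feature Elimination with a random forest, keeping the top 10 features
rfe = RFE(estimator=RandomForestClassifier(n_estimators=100, random_state=42),
          n_features_to_select=10)
rfe.fit(X_train_scaled, y_train)
print("Top features selected by RFE:", list(X.columns[rfe.support_]))
```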
### Model Definition
The model consists of a CNN for feature extraction followed by an LSTM layer for sequence modeling. The model architecture is defined as follows:
- Input Layer: Accepts input data of shape (batch_size, 1, num_features); each sample is fed to the 1-D convolutions as a single-channel sequence.
- CNN Layers: Multiple convolutional layers to extract features.
- CBAM Layers: Spatial Attention and Channel Attention.
- LSTM Layer: Processes the extracted features over time.
- Output Layer: Produces the final classification output.
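For orientation, the sketch below condenses this architecture into a minimal runnable module. It mirrors the CBAM gates used in `EmoMA-Net.py`, but the layer sizes are illustrative, the attention layer over the LSTM output is omitted for brevity, and the class name `EmoMANetSketch` exists only for this example; `EmoMA-Net.py` remains the reference implementation.

```python
# Minimal sketch of the CNN + CBAM + LSTM architecture (illustrative sizes).
import torch
import torch.nn as nn
import torch.nn.functional as F

class CBAM(nn.Module):
    """Channel and spatial attention gates, mirroring the CBAM block in EmoMA-Net.py."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.channel_gate = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, channels // reduction, kernel_size=1), nn.ReLU(),
            nn.Conv1d(channels // reduction, channels, kernel_size=1), nn.Sigmoid())
        self.spatial_gate = nn.Sequential(
            nn.Conv1d(channels, 1, kernel_size=7, padding=3), nn.Sigmoid())

    def forward(self, x):
        return self.channel_gate(x) * x + self.spatial_gate(x) * x

class EmoMANetSketch(nn.Module):
    def __init__(self, n_features=10, lstm_hidden=50, n_classes=2):
        super().__init__()
        self.conv1, self.cbam1 = nn.Conv1d(1, 16, 3, padding=1), CBAM(16)
        self.conv2, self.cbam2 = nn.Conv1d(16, 32, 3, padding=1), CBAM(32)
        self.pool = nn.MaxPool1d(2)
        self.lstm = nn.LSTM(32 * (n_features // 4), lstm_hidden, batch_first=True)
        self.fc = nn.Linear(lstm_hidden, n_classes)

    def forward(self, x):                      # x: (batch, 1, n_features)
        x = self.pool(self.cbam1(F.relu(self.conv1(x))))
        x = self.pool(self.cbam2(F.relu(self.conv2(x))))
        x = x.view(x.size(0), 1, -1)           # one "time step" for the LSTM
        out, _ = self.lstm(x)
        return F.log_softmax(self.fc(out[:, -1]), dim=1)

logits = EmoMANetSketch()(torch.randn(4, 1, 10))  # -> tensor of shape (4, 2)
```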
### Training Process
1. Initialize Model and Optimizer:
   - Define the model and optimizer (Adam).
   - Set the loss function (CrossEntropyLoss).
2. Training Loop:
   - Train the model over multiple epochs.
   - Calculate the loss and accuracy during each epoch.
   - Track the best model based on the validation accuracy.
### Evaluation
1. Testing Loop:
   - Evaluate the model on the test set.
   - Calculate the final test accuracy.
   - Generate a confusion matrix to assess performance.
### Running the Code
To run the code, follow these steps:
1. Install Dependencies:
   Ensure you have Python installed.
   Install the required libraries using pip:
```bash
pip install torch pandas scikit-learn numpy
```
2. Prepare the Dataset:
   - Place the merged.csv file in the specified directory.
   - Modify the path to the CSV file in the code if necessary.
3. Execute the Script:
   Run the script using Python:
```bash
python EmoMA-Net.py
```
### Results
The script prints out the training progress, including the loss and accuracy at each epoch. After training, it displays the maximum test accuracy and F1-score achieved across all folds of cross-validation.
### License
This project is licensed under the MIT License - see the LICENSE file for details.
### Note
Make sure to adjust the file paths and other settings according to your specific environment and requirements.

### REFERENCES
[1] Philip Schmidt, Attila Reiss, Robert Duerichen, Claus Marberger, and Kristof Van Laerhoven. 2018. Introducing WESAD, a multimodal dataset for wearable stress and affect detection. Proceedings of the 20th ACM International Conference on Multimodal Interaction (October 2018). DOI:http://dx.doi.org/10.1145/3242969.3242985

[2] Sanghyun Woo, Jongchan Park, Joon-Young Lee, and In So Kweon. 2018. CBAM: Convolutional Block Attention Module. (July 2018). Retrieved July 24, 2024 from https://arxiv.org/abs/1807.06521

[3] Y. Liu, J. Kang, Y. Li, and B. Ji. 2021. A network intrusion detection method based on CNN and CBAM. IEEE INFOCOM 2021 - IEEE Conference on Computer Communications Workshops (INFOCOM WKSHPS). https://doi.org/10.1109/infocomwkshps51825.2021.9484553

[4] Maximilian Beck et al. 2024. xLSTM: Extended Long Short-Term Memory. (May 2024). Retrieved July 24, 2024 from https://arxiv.org/abs/2405.04517

--------------------------------------------------------------------------------
/Data_Preprocessing/raw_data_processing.py:
--------------------------------------------------------------------------------
import os
import pickle
import numpy as np
import pandas as pd
import scipy.signal as scisig
import cvxEDA

# Sampling frequencies of the E4 (wrist) sensors
fs_dict = {'ACC': 32, 'BVP': 64, 'EDA': 4, 'TEMP': 4, 'label': 700, 'Resp': 700}
savePath = 'data'
subject_feature_path = '/subject_feats'

# Create the output directories if they do not exist
if not os.path.exists(savePath):
    os.makedirs(savePath)
if not os.path.exists(savePath + subject_feature_path):
    os.makedirs(savePath + subject_feature_path)


# SubjectData class for handling the data of a specific subject
class SubjectData:

    def __init__(self, main_path, subject_number):
        self.name = f'S{subject_number}'
        self.subject_keys = ['signal', 'label', 'subject']
        self.signal_keys = ['chest', 'wrist']
        self.chest_keys = ['ACC', 'ECG', 'EMG', 'EDA', 'Temp', 'Resp']
        self.wrist_keys = ['ACC', 'BVP', 'EDA', 'TEMP']
        with open(os.path.join(main_path, self.name) + '/' + self.name + '.pkl', 'rb') as file:
            self.data = pickle.load(file, encoding='latin1')

    # Get the wrist data
    def get_wrist_data(self):
        data = self.data['signal']['wrist']
        data.update({'Resp': self.data['signal']['chest']['Resp']})
        return data

    # Get the chest data
    def get_chest_data(self):
        return self.data['signal']['chest']


# Design a Butterworth low-pass filter
def butter_lowpass(cutoff, fs, order=5):
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = scisig.butter(order, normal_cutoff, btype='low', analog=False)
    return b, a


# Apply the low-pass filter
def butter_lowpass_filter(data, cutoff, fs, order=5):
    b, a = butter_lowpass(cutoff, fs, order=order)
    y = scisig.lfilter(b, a, data)
    return y


# Filter a signal with an FIR filter
def filterSignalFIR(data, cutoff=0.4, numtaps=64):
    f = cutoff / (fs_dict['ACC'] / 2.0)
    FIR_coeff = scisig.firwin(numtaps, f)
    return scisig.lfilter(FIR_coeff, 1, data)


# Process the data of a specific subject
def process_subject_data(subject_id):
    global savePath

    # Create the subject data object
    subject = SubjectData(main_path='data/WESAD', subject_number=subject_id)

    # Get the wrist data (includes the respiration signal)
    e4_data_dict = subject.get_wrist_data()

    # Convert the data to DataFrames
    eda_df = pd.DataFrame(e4_data_dict['EDA'], columns=['EDA'])
    bvp_df = pd.DataFrame(e4_data_dict['BVP'], columns=['BVP'])
    acc_df = pd.DataFrame(e4_data_dict['ACC'], columns=['ACC_x', 'ACC_y', 'ACC_z'])
    temp_df = pd.DataFrame(e4_data_dict['TEMP'], columns=['TEMP'])
    resp_df = pd.DataFrame(e4_data_dict['Resp'], columns=['Resp'])

    # Low-pass filter the EDA signal
    eda_df['EDA'] = butter_lowpass_filter(eda_df['EDA'], 1.0, fs_dict['EDA'], 6)

    # FIR-filter the acceleration signals
    for col in acc_df.columns:
        acc_df[col] = filterSignalFIR(acc_df[col])

    # Add time indices so the frames can be merged
    eda_df.index = [(1 / fs_dict['EDA']) * i for i in range(len(eda_df))]
    bvp_df.index = [(1 / fs_dict['BVP']) * i for i in range(len(bvp_df))]
    acc_df.index = [(1 / fs_dict['ACC']) * i for i in range(len(acc_df))]
    temp_df.index = [(1 / fs_dict['TEMP']) * i for i in range(len(temp_df))]
    resp_df.index = [(1 / fs_dict['Resp']) * i for i in range(len(resp_df))]

    # Convert the indices to datetime format
    eda_df.index = pd.to_datetime(eda_df.index, unit='s')
    bvp_df.index = pd.to_datetime(bvp_df.index, unit='s')
    temp_df.index = pd.to_datetime(temp_df.index, unit='s')
    acc_df.index = pd.to_datetime(acc_df.index, unit='s')
    resp_df.index = pd.to_datetime(resp_df.index, unit='s')

    # Merge all data
    df = eda_df.join(bvp_df, how='outer')
    df = df.join(temp_df, how='outer')
    df = df.join(acc_df, how='outer')
    df = df.join(resp_df, how='outer')

    # Save as a CSV file
    df.to_csv(f'{savePath}{subject_feature_path}/S{subject_id}_raw.csv')

    # Clear the object to free memory
    subject = None


# Combine the data of all subjects
def combine_files(subjects):
    df_list = []
    for s in subjects:
        df = pd.read_csv(f'{savePath}{subject_feature_path}/S{s}_raw.csv', index_col=0)
        df['subject'] = s
        df_list.append(df)

    combined_df = pd.concat(df_list)
    combined_df.reset_index(drop=True, inplace=True)
    combined_df.to_csv(f'{savePath}/combined_raw_data.csv')

    counts = combined_df['subject'].value_counts()
    print('Number of samples per subject:')
    for subject, number in zip(counts.index, counts.values):
        print(f'Subject {subject}: {number}')


# Main function
if __name__ == '__main__':
    subject_ids = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17]

    for subject_id in subject_ids:
        print(f'Processing data for S{subject_id}...')
        process_subject_data(subject_id)

    combine_files(subject_ids)
    print('Processing complete.')

--------------------------------------------------------------------------------
/Data_Preprocessing/load_files.py:
--------------------------------------------------------------------------------
import pandas as pd
import scipy.signal as scisig
import os
import numpy as np


def get_user_input(prompt):
    try:
        return raw_input(prompt)
    except NameError:
        return input(prompt)


def getInputLoadFile():
    '''Asks user for type of file and file path. Loads corresponding data.

    OUTPUT:
        data: DataFrame, index is a list of timestamps at 8Hz, columns include
            AccelZ, AccelY, AccelX, Temp, EDA, filtered_eda
    '''
    print("Please enter information about your EDA file... ")
") 22 | dataType = 'e4' 23 | if dataType=='q': 24 | filepath = get_user_input("\tFile path: ") 25 | filepath_confirm = filepath 26 | data = loadData_Qsensor(filepath) 27 | elif dataType=='e4': 28 | filepath = get_user_input("\tPath to E4 directory: ") 29 | filepath_confirm = os.path.join(filepath,"EDA.csv") 30 | data = loadData_E4(filepath) 31 | elif dataType=='shimmer': 32 | filepath = get_user_input("\tFile path: ") 33 | filepath_confirm = filepath 34 | data = loadData_shimmer(filepath) 35 | elif dataType=="misc": 36 | filepath = get_user_input("\tFile path: ") 37 | filepath_confirm = filepath 38 | data = loadData_misc(filepath) 39 | else: 40 | print("Error: not a valid file choice") 41 | 42 | return data, filepath_confirm 43 | 44 | def getOutputPath(): 45 | print("") 46 | print("Where would you like to save the computed output file?") 47 | outfile = get_user_input('\tFile name: ') 48 | outputPath = get_user_input('\tFile directory (./ for this directory): ') 49 | fullOutputPath = os.path.join(outputPath,outfile) 50 | if fullOutputPath[-4:] != '.csv': 51 | fullOutputPath = fullOutputPath+'.csv' 52 | return fullOutputPath 53 | 54 | def loadData_Qsensor(filepath): 55 | ''' 56 | This function loads the Q sensor data, uses a lowpass butterworth filter on the EDA signal 57 | Note: currently assumes sampling rate of 8hz, 16hz, 32hz; if sampling rate is 16hz or 32hz the signal is downsampled 58 | 59 | INPUT: 60 | filepath: string, path to input file 61 | 62 | OUTPUT: 63 | data: DataFrame, index is a list of timestamps at 8Hz, columns include AccelZ, AccelY, AccelX, Temp, EDA, filtered_eda 64 | ''' 65 | # Get header info 66 | try: 67 | header_info = pd.io.parsers.read_csv(filepath, nrows=5) 68 | except IOError: 69 | print("Error!! Couldn't load file, make sure the filepath is correct and you are using a csv from the q sensor software\n\n") 70 | return 71 | 72 | # Get sample rate 73 | sampleRate = int((header_info.iloc[3,0]).split(":")[1].strip()) 74 | 75 | # Get the raw data 76 | data = pd.io.parsers.read_csv(filepath, skiprows=7) 77 | data = data.reset_index() 78 | 79 | # Reset the index to be a time and reset the column headers 80 | data.columns = ['AccelZ','AccelY','AccelX','Battery','Temp','EDA'] 81 | 82 | # Get Start Time 83 | startTime = pd.to_datetime(header_info.iloc[4,0][12:-10]) 84 | 85 | # Make sure data has a sample rate of 8Hz 86 | data = interpolateDataTo8Hz(data,sampleRate,startTime) 87 | 88 | # Remove Battery Column 89 | data = data[['AccelZ','AccelY','AccelX','Temp','EDA']] 90 | 91 | # Get the filtered data using a low-pass butterworth filter (cutoff:1hz, fs:8hz, order:6) 92 | data['filtered_eda'] = butter_lowpass_filter(data['EDA'], 1.0, 8, 6) 93 | 94 | return data 95 | 96 | def _loadSingleFile_E4(filepath,list_of_columns, expected_sample_rate,freq): 97 | # Load data 98 | data = pd.read_csv(filepath) 99 | 100 | # Get the startTime and sample rate 101 | startTime = pd.to_datetime(float(data.columns.values[0]),unit="s") 102 | sampleRate = float(data.iloc[0][0]) 103 | data = data[data.index!=0] 104 | data.index = data.index-1 105 | 106 | # Reset the data frame assuming expected_sample_rate 107 | data.columns = list_of_columns 108 | if sampleRate != expected_sample_rate: 109 | print('ERROR, NOT SAMPLED AT {0}HZ. 

    # Make sure data has a sample rate of 8Hz
    data = interpolateDataTo8Hz(data,sampleRate,startTime)

    return data


def loadData_E4(filepath):
    # Load EDA data
    eda_data = _loadSingleFile_E4(os.path.join(filepath,'EDA.csv'),["EDA"],4,"250L")
    # Get the filtered data using a low-pass butterworth filter (cutoff:1hz, fs:8hz, order:6)
    eda_data['filtered_eda'] = butter_lowpass_filter(eda_data['EDA'], 1.0, 8, 6)

    # Load ACC data
    acc_data = _loadSingleFile_E4(os.path.join(filepath,'ACC.csv'),["AccelX","AccelY","AccelZ"],32,"31250U")
    # Scale the accelerometer to +-2g
    acc_data[["AccelX","AccelY","AccelZ"]] = acc_data[["AccelX","AccelY","AccelZ"]]/64.0

    # Load Temperature data
    temperature_data = _loadSingleFile_E4(os.path.join(filepath,'TEMP.csv'),["Temp"],4,"250L")

    data = eda_data.join(acc_data, how='outer')
    data = data.join(temperature_data, how='outer')

    # E4 sometimes records different length files - adjust as necessary
    min_length = min(len(acc_data), len(eda_data), len(temperature_data))

    return data[:min_length]

def loadData_shimmer(filepath):
    data = pd.read_csv(filepath, sep='\t', skiprows=(0,1))

    orig_cols = data.columns
    rename_cols = {}

    for search, new_col in [['Timestamp','Timestamp'],
                            ['Accel_LN_X', 'AccelX'], ['Accel_LN_Y', 'AccelY'], ['Accel_LN_Z', 'AccelZ'],
                            ['Skin_Conductance', 'EDA']]:
        orig = [c for c in orig_cols if search in c]
        if len(orig) == 0:
            continue
        rename_cols[orig[0]] = new_col

    data.rename(columns=rename_cols, inplace=True)

    # TODO: Assuming no temperature is recorded
    data['Temp'] = 0

    # Drop the units row and unnecessary columns
    data = data[data['Timestamp'] != 'ms']
    data.index = pd.to_datetime(data['Timestamp'], unit='ms')
    data = data[['AccelZ', 'AccelY', 'AccelX', 'Temp', 'EDA']]

    for c in ['AccelZ', 'AccelY', 'AccelX', 'Temp', 'EDA']:
        data[c] = pd.to_numeric(data[c])

    # Convert to 8Hz
    data = data.resample("125L").mean()
    data.interpolate(inplace=True)

    # Get the filtered data using a low-pass butterworth filter (cutoff:1hz, fs:8hz, order:6)
    data['filtered_eda'] = butter_lowpass_filter(data['EDA'], 1.0, 8, 6)

    return data


def loadData_getColNames(data_columns):
    print("Here are the data columns of your file: ")
    print(data_columns)

    # Find the column names for each of the 5 data streams
    colnames = ['EDA data','Temperature data','Acceleration X','Acceleration Y','Acceleration Z']
    new_colnames = ['','','','','']

    for i in range(len(new_colnames)):
        new_colnames[i] = get_user_input("Column name that contains "+colnames[i]+": ")
        while (new_colnames[i] not in data_columns):
            print("Column not found. Please try again")
            print("Here are the data columns of your file: ")
            print(data_columns)

            new_colnames[i] = get_user_input("Column name that contains "+colnames[i]+": ")

    # Get user input on sample rate
    sampleRate = get_user_input("Enter sample rate (must be an integer power of 2): ")
    while (sampleRate.isdigit()==False) or (np.log(int(sampleRate))/np.log(2) != np.floor(np.log(int(sampleRate))/np.log(2))):
        print("Not an integer power of two")
        sampleRate = get_user_input("Enter sample rate (must be an integer power of 2): ")
    sampleRate = int(sampleRate)

    # Get user input on start time
    startTime = pd.to_datetime(get_user_input("Enter a start time (format: YYYY-MM-DD HH:MM:SS): "))
    while type(startTime)==str:
        print("Not a valid date/time")
        startTime = pd.to_datetime(get_user_input("Enter a start time (format: YYYY-MM-DD HH:MM:SS): "))


    return sampleRate, startTime, new_colnames


def loadData_misc(filepath):
    # Load data
    data = pd.read_csv(filepath)

    # Get the correct colnames
    sampleRate, startTime, new_colnames = loadData_getColNames(data.columns.values)

    data.rename(columns=dict(zip(new_colnames,['EDA','Temp','AccelX','AccelY','AccelZ'])), inplace=True)
    data = data[['AccelZ','AccelY','AccelX','Temp','EDA']]

    # Make sure data has a sample rate of 8Hz
    data = interpolateDataTo8Hz(data,sampleRate,startTime)

    # Get the filtered data using a low-pass butterworth filter (cutoff:1hz, fs:8hz, order:6)
    data['filtered_eda'] = butter_lowpass_filter(data['EDA'], 1.0, 8, 6)

    return data

def interpolateDataTo8Hz(data,sample_rate,startTime):
    if sample_rate<8:
        # Upsample by linear interpolation
        if sample_rate==2:
            data.index = pd.date_range(start=startTime, periods=len(data), freq='500L')
        elif sample_rate==4:
            data.index = pd.date_range(start=startTime, periods=len(data), freq='250L')
        data = data.resample("125L").mean()
    else:
        if sample_rate>8:
            # Downsample
            idx_range = list(range(0,len(data))) # TODO: double check this one
            data = data.iloc[idx_range[0::int(int(sample_rate)/8)]]
        # Set the index to be 8Hz
        data.index = pd.date_range(start=startTime, periods=len(data), freq='125L')

    # Interpolate all empty values
    data = interpolateEmptyValues(data)
    return data

def interpolateEmptyValues(data):
    cols = data.columns.values
    for c in cols:
        data.loc[:, c] = data[c].interpolate()

    return data

def butter_lowpass(cutoff, fs, order=5):
    # Filtering Helper functions
    nyq = 0.5 * fs
    normal_cutoff = cutoff / nyq
    b, a = scisig.butter(order, normal_cutoff, btype='low', analog=False)
    return b, a

def butter_lowpass_filter(data, cutoff, fs, order=5):
    # Filtering Helper functions
    b, a = butter_lowpass(cutoff, fs, order=order)
    y = scisig.lfilter(b, a, data)
    return y
--------------------------------------------------------------------------------
/EmoMA-Net.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import RFE
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
from sklearn.metrics import f1_score

warnings.filterwarnings('ignore')

# Custom dataset class for loading the WESAD data
class WESADDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe.drop('subject', axis=1)
        self.labels = self.dataframe['label'].values
        self.dataframe.drop('label', axis=1, inplace=True)

    def __getitem__(self, idx):
        x = self.dataframe.iloc[idx].values
        x = x.reshape(1, -1)  # Adjust the input shape for the CNN
        y = self.labels[idx]
        return torch.Tensor(x), y

    def __len__(self):
        return len(self.dataframe)

# Define the list of features
feats = ['BVP_mean', 'BVP_std', 'BVP_min', 'BVP_max',
         'EDA_phasic_mean', 'EDA_phasic_std', 'EDA_phasic_min', 'EDA_phasic_max', 'EDA_smna_mean',
         'EDA_smna_std', 'EDA_smna_min', 'EDA_smna_max', 'EDA_tonic_mean',
         'EDA_tonic_std', 'EDA_tonic_min', 'EDA_tonic_max', 'Resp_mean',
         'Resp_std', 'Resp_min', 'Resp_max', 'TEMP_mean', 'TEMP_std', 'TEMP_min',
         'TEMP_max', 'TEMP_slope', 'BVP_peak_freq', 'age', 'height',
         'weight', 'subject', 'label']

# Function to get data loaders for training and testing
def get_data_loaders(df, train_subjects, test_subjects, train_batch_size=25, test_batch_size=5):

    # Split the training and test sets based on randomly selected people
    train_df = df[df['subject'].isin(train_subjects)].reset_index(drop=True)
    test_df = df[df['subject'].isin(test_subjects)].reset_index(drop=True)

    # Create data loaders for the training and test sets
    train_dset = WESADDataset(train_df)
    test_dset = WESADDataset(test_df)

    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=train_batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dset, batch_size=test_batch_size)

    return train_loader, test_loader

# Function to calculate the output size of the convolutional layer
def calculate_conv_output_dim(input_dim, kernel_size, stride, padding):
    return (input_dim - kernel_size + 2 * padding) // stride + 1

# Define the attention mechanism
class Attention(nn.Module):
    def __init__(self, lstm_hidden_dim):
        super(Attention, self).__init__()
        self.attention = nn.Linear(lstm_hidden_dim, 1, bias=False)

    def forward(self, lstm_output):
        # Compute attention weights
        attn_weights = F.softmax(self.attention(lstm_output), dim=1)
        # Weighted average of the LSTM output using the attention weights
        attn_output = torch.bmm(attn_weights.transpose(1, 2), lstm_output).squeeze(1)
        return attn_output

# Define the Convolutional Block Attention Module (CBAM)
class CBAM(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(CBAM, self).__init__()
        self.channel_gate = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(in_channels, in_channels // reduction_ratio, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(in_channels // reduction_ratio, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.spatial_gate = nn.Sequential(
            nn.Conv1d(in_channels, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        channel_wise = self.channel_gate(x) * x
        spatial_wise = self.spatial_gate(x) * x
        return channel_wise + spatial_wise

# Define the CNN-LSTM model with attention and CBAM
class CNNLSTMModel(nn.Module):
    def __init__(self, lstm_hidden_dim=50, num_lstm_layers=1):
        super(CNNLSTMModel, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.cbam1 = CBAM(in_channels=16)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.cbam2 = CBAM(in_channels=32)

        # Calculate the dimension of the output of the convolutional layers
        conv_output_dim = calculate_conv_output_dim(10, 3, 1, 1) // 4

        # Flattened feature size produced by the convolutional part
        self.conv_output_size = 32 * conv_output_dim

        # Define the LSTM layer
        self.lstm = nn.LSTM(input_size=self.conv_output_size, hidden_size=lstm_hidden_dim, num_layers=num_lstm_layers, batch_first=True)

        # Define the attention layer
        self.attention = Attention(lstm_hidden_dim)

        # Define the fully connected layers
        self.fc1 = nn.Linear(lstm_hidden_dim, 128)
        self.fc2 = nn.Linear(128, 3)  # Output layer with 3 units (the labels used below are binary, so only classes 0 and 1 occur)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # CNN part
        x = F.relu(self.conv1(x))
        x = self.cbam1(x)  # Apply CBAM
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.cbam2(x)  # Apply CBAM
        x = self.pool(x)

        # Reshape to fit the input of the LSTM layer
        x = x.view(x.size(0), 1, -1)  # (batch_size, sequence_length=1, input_size=conv_output_size)

        # LSTM part
        lstm_out, _ = self.lstm(x)

        # Attention mechanism
        attn_output = self.attention(lstm_out)

        # Fully connected layers
        x = F.relu(self.fc1(attn_output))
        x = self.dropout(x)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

# Function to train the model
def train(model, optimizer, train_loader, validation_loader):
    # Initialize a dictionary that records the loss and accuracy during training and validation
    history = {'train_loss': {}, 'train_acc': {}, 'valid_loss': {}, 'valid_acc': {}}
    # Training the model
    for epoch in range(num_epochs):
        total = 0
        correct = 0
        trainlosses = []

        for batch_index, (images, labels) in enumerate(train_loader):
            # Send to GPU (device)
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images.float())

            # Loss
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            trainlosses.append(loss.item())

            # Compute accuracy
            _, argmax = torch.max(outputs, 1)
            correct += (labels == argmax).sum().item()
            total += len(labels)

        history['train_loss'][epoch] = np.mean(trainlosses)
        history['train_acc'][epoch] = correct / total

        if epoch % 10 == 0:
            with torch.no_grad():

                losses = []
                total = 0
                correct = 0

                for images, labels in validation_loader:
                    images, labels = images.to(device), labels.to(device)

                    # Forward pass
                    outputs = model(images.float())
                    loss = criterion(outputs, labels)

                    # Compute accuracy
                    _, argmax = torch.max(outputs, 1)
                    correct += (labels == argmax).sum().item()
                    total += len(labels)

                    losses.append(loss.item())

                history['valid_acc'][epoch] = np.round(correct / total, 3)
                history['valid_loss'][epoch] = np.mean(losses)

                print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {np.mean(losses):.4}, Acc: {correct / total:.2}')

    return history

# Define the function to test the model
def test(model, validation_loader):
    print('Evaluating model...')
    # Test
    model.eval()

    total = 0
    correct = 0
    testlosses = []
    correct_labels = []
    predictions = []

    with torch.no_grad():
        for batch_index, (images, labels) in enumerate(validation_loader):
            # Send to GPU (device)
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images.float())

            # Compute actual probabilities
            probabilities = torch.exp(outputs)

            # Loss
            loss = criterion(outputs, labels)

            testlosses.append(loss.item())

            # Compute accuracy
            _, argmax = torch.max(outputs, 1)
            correct += (labels == argmax).sum().item()
            total += len(labels)

            correct_labels.extend(labels)
            predictions.extend(argmax.cpu())

    test_loss = np.mean(testlosses)
    accuracy = np.round(correct / total, 2)
    print(f'Loss: {test_loss:.4}, Acc: {accuracy:.2}')

    # Convert to numpy arrays for F1 score calculation
    y_true = np.array([label.item() for label in correct_labels])
    y_pred = np.array([label.item() for label in predictions])

    f1 = f1_score(y_true, y_pred, average='binary')  # For binary classification
    print(f'F1 Score: {f1:.2}')

    cm = confusion_matrix(y_true, y_pred)
    return cm, test_loss, accuracy, f1

# Load the dataset
df = pd.read_csv('data/merged.csv', index_col=0)
# Get all subject IDs
subject_id_list = df['subject'].unique()

# Define a function to binarize labels: 0 and 1 stay class 0, all other labels become class 1
def change_label(label):
    if label == 0 or label == 1:
        return 0
    else:
        return 1

# Apply the label-changing function to the label column of the dataset
df['label'] = df['label'].apply(change_label)

# Select feature columns
X = df[feats[:-2]]  # Exclude 'subject' and 'label' columns
y = df['label']

# Split the training set and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create a StandardScaler object
scaler = StandardScaler()

# Fit the scaler object using the training set data and transform
X_train_scaled = scaler.fit_transform(X_train)

# Transform the test set data using the same scaler object
X_test_scaled = scaler.transform(X_test)

# Train a Random Forest model
rf = RandomForestClassifier(n_estimators=100, random_state=42)
# Create an RFE object, specifying the Random Forest model and the target number of features to select
rfe = RFE(estimator=rf, n_features_to_select=10)

# Fit the data and obtain the selected features
rfe.fit(X_train_scaled, y_train)

# Get the indices of the selected features
selected_features_index = rfe.support_

# Use the selected feature indices to obtain the selected features
selected_features = X.columns[selected_features_index]

print("Top features selected by RFE:")
print(selected_features)

# Rebuild the dataset using the selected features
df = df[selected_features.tolist() + ['label', 'subject']]

# Set the batch sizes for training and testing
train_batch_size = 25
test_batch_size = 5

# Set the device, preferring GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set the number of training epochs
num_epochs = 100

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()

# Initialize lists to store results
histories = []
confusion_matrices = []
test_losses = []
test_accs = []

# Set the number of folds for cross-validation
num_folds = 2
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)
fold_count = 0  # Initialize the fold counter

# Initialize the maximum accuracy and the best model
max_acc = 0.0
f1_scores = []
best_model = None
import copy

for train_index, test_index in kf.split(df):
    fold_count += 1
    print(f'Final training and testing - Fold {fold_count}:')  # Print the current fold number
    train_df, test_df = df.iloc[train_index], df.iloc[test_index]

    train_loader, test_loader = get_data_loaders(df, train_df['subject'].unique(), test_df['subject'].unique())

    model = CNNLSTMModel(lstm_hidden_dim=100, num_lstm_layers=2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    history = train(model, optimizer, train_loader, test_loader)
    histories.append(history)

    cm, test_loss, accuracy, f1 = test(model, test_loader)
    test_losses.append(test_loss)
    test_accs.append(accuracy)
    f1_scores.append(f1)

    # Test the model
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = correct / total

    # Check if the current model is better than the previous ones
    if test_acc > max_acc:
        max_acc = test_acc
        best_model = copy.deepcopy(model)
        print(f'New best model found with accuracy: {max_acc:.4f}')

# Print the maximum test accuracy and F1-score
print(f'Maximum test accuracy over {num_folds} folds: {max_acc:.4f}')
max_f1 = max(f1_scores)
print(f'Maximum F1-score over {num_folds} folds: {max_f1:.4f}')
--------------------------------------------------------------------------------