├── figures
│   ├── logo.png
│   ├── emodb_cfm.png
│   ├── iemocap_cfm.png
│   └── Aggregated_SigWavNet_V.png
├── requirements.txt
├── LICENSE
├── .gitignore
├── Data_exploration
│   └── data_exploration.ipynb
├── custom_layers.py
├── model.py
├── README.md
├── main.py
└── utils.py
/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition/HEAD/figures/logo.png
--------------------------------------------------------------------------------
/figures/emodb_cfm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition/HEAD/figures/emodb_cfm.png
--------------------------------------------------------------------------------
/figures/iemocap_cfm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition/HEAD/figures/iemocap_cfm.png
--------------------------------------------------------------------------------
/figures/Aggregated_SigWavNet_V.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition/HEAD/figures/Aggregated_SigWavNet_V.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython==7.31.1
2 | matplotlib==3.4.3
3 | numpy==1.21.6
4 | pandas==1.3.5
5 | PyWavelets==1.1.1
6 | ray==1.12.0
7 | scikit_learn==0.23.2
8 | seaborn==0.11.2
9 | torch==1.13.1
10 | torchaudio==0.13.1
11 | tqdm==4.62.3
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Alaa Nfissi
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/Data_exploration/data_exploration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "2cba627e-c6ee-479c-8c05-76ea5fd146de",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "#!/usr/bin/env python3\n",
11 | "# -*- coding: utf-8 -*-\n",
12 | "\n",
13 | "\"\"\"\n",
14 | "Title: Dataset Preparation for Speech Emotion Recognition (SigWavNet)\n",
15 | "Author: Alaa Nfissi\n",
16 | "Date: March 31, 2024\n",
17 | "Description: This script is responsible for preparing the speech emotion recognition datasets, \n",
18 | "including loading, partitioning, and preprocessing the data for model training and evaluation.\n",
19 | "\"\"\"\n",
20 | "\n",
21 | "import os\n",
22 | "import pandas as pd\n",
23 | "import torchaudio\n",
24 | "\n",
25 | "\n",
26 | "# Define the path to the IEMOCAP dataset. This needs to be updated to the correct path where the IEMOCAP dataset is stored.\n",
27 | "iemocap_data_path = \"IEMOCAP DATA FOLDER PATH\"\n",
28 | "IEMOCAP_path = os.path.abspath(iemocap_data_path)\n",
29 | "dir_list_IEMOCAP = os.listdir(IEMOCAP_path)\n",
30 | "\n",
31 | "# Initialize lists to hold the paths to audio files and label files.\n",
32 | "records = []\n",
33 | "label_files = []\n",
34 | "\n",
35 | "# Process each session of the IEMOCAP dataset. There are five sessions in total.\n",
36 | "for i in range(1,6):\n",
37 | " # List directories for each session.\n",
38 | " wav_list = os.listdir(IEMOCAP_path+f'/IEMOCAP/Session{i}/sentences/wav/')\n",
39 | " \n",
40 | " # Extend the records list with paths to each wav file.\n",
41 | " for j in wav_list:\n",
42 | " records.extend([IEMOCAP_path+f'/IEMOCAP/Session{i}/sentences/wav/'+str(j)+'/'+k for k in os.listdir(IEMOCAP_path+f'/IEMOCAP/Session{i}/sentences/wav/'+str(j)+'/')])\n",
43 | " \n",
44 | " # List emotion label files for each session.\n",
45 | " label_list = os.listdir(IEMOCAP_path+f'/IEMOCAP/Session{i}/dialog/EmoEvaluation/')\n",
46 | " \n",
47 | " # Append label file paths to the label_files list.\n",
48 | " for k in label_list:\n",
49 | " if len(str(k).split('.')) == 2: # Check if the file name format is correct.\n",
50 | " label_files.append(IEMOCAP_path+f'/IEMOCAP/Session{i}/dialog/EmoEvaluation/'+str(k))\n",
51 | "\n",
52 | "# Create a dictionary to map label files to their corresponding audio files.\n",
53 | "dic = {}\n",
54 | "for i in label_files:\n",
55 | " dic.update({i : [j for j in records if j.split('/')[14].startswith(i.split('/')[13].split('.')[0])]})\n",
56 | "\n",
57 | "# Map audio file paths to their respective emotions.\n",
58 | "segments_emotions = {}\n",
59 | "for i in dic.keys():\n",
60 | " with open(i) as f:\n",
61 | " for line in f:\n",
62 | " if i.split('/')[13].split('.')[0] in line:\n",
63 | " segments_emotions.update({ [j for j in dic.get(i) if line.split('\\t')[1]+'.wav' in j][0] \n",
64 | " : line.split('\\t')[2] })\n",
65 | "\n",
66 | "# Convert the dictionary to a DataFrame and save it as a CSV file.\n",
67 | "IEMOCAP_df = pd.DataFrame({'path': segments_emotions.keys(), 'source': 'IEMOCAP', 'label':segments_emotions.values()})\n",
68 | "IEMOCAP_df.to_csv('IEMOCAP_dataset.csv', index=False)\n",
69 | "IEMOCAP_df.head()"
70 | ]
71 | }
72 | ],
73 | "metadata": {
74 | "kernelspec": {
75 | "display_name": "Python 3 (ipykernel)",
76 | "language": "python",
77 | "name": "python3"
78 | },
79 | "language_info": {
80 | "codemirror_mode": {
81 | "name": "ipython",
82 | "version": 3
83 | },
84 | "file_extension": ".py",
85 | "mimetype": "text/x-python",
86 | "name": "python",
87 | "nbconvert_exporter": "python",
88 | "pygments_lexer": "ipython3",
89 | "version": "3.7.7"
90 | }
91 | },
92 | "nbformat": 4,
93 | "nbformat_minor": 5
94 | }
95 |
--------------------------------------------------------------------------------
/custom_layers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Title: Custom Layers for Speech Emotion Recognition Model (SigWavNet)
6 | Author: Alaa Nfissi
7 | Date: March 31, 2024
8 | Description: This file defines custom neural network layers and attention mechanisms
9 | specifically designed for enhancing the speech emotion recognition model's performance.
10 | """
11 |
12 | import torch
13 | import torch.nn as nn
14 | from collections import OrderedDict
15 | import torch.nn.functional as F
16 |
17 | d = 0
18 | device = torch.device(f"cuda:{d}")
19 |
20 | class Kernel(nn.Module):
21 |
22 | """
23 | Represents a learnable kernel, which can be initialized either randomly or with a specified array.
24 | """
25 |
26 | def __init__(self, kernelInit=20, trainKern=True):
27 |
28 | """
29 | Initializes the Kernel object.
30 |
31 | Parameters:
32 | - kernelInit (int, list): Initial size of the kernel or the initial array values.
33 | - trainKern (bool): Specifies whether the kernel is learnable.
34 | """
35 |
36 | super(Kernel, self).__init__()
37 | self.trainKern = trainKern
38 |
39 | if isinstance(kernelInit, int):
40 | self.kernelSize = kernelInit
41 | self.kernel = nn.Parameter(torch.empty(self.kernelSize,), requires_grad=self.trainKern)
42 | nn.init.normal_(self.kernel)
43 | else:
44 | self.kernelSize = len(kernelInit)
45 | self.kernel = nn.Parameter(torch.Tensor(kernelInit).view(self.kernelSize,), requires_grad=self.trainKern)
46 |
47 | def forward(self, inputs):
48 |
49 | """
50 | Forward pass for the Kernel. It returns the kernel itself since it's a learnable parameter.
51 | """
52 |
53 | return self.kernel
54 |
55 |
56 | class LowPassWave(nn.Module):
57 |
58 | """
59 | Performs low-pass filtering on the input signal using a convolution operation with stride 2.
60 | """
61 |
62 | def __init__(self):
63 | super(LowPassWave, self).__init__()
64 |
65 | def forward(self, inputs):
66 |
67 | """
68 | Applies low-pass filtering by convolving the input with the kernel.
69 |
70 | Parameters:
71 | - inputs (tuple): A tuple containing the input signal and the filter kernel.
72 |
73 | Returns:
74 | - The low-pass filtered signal.
75 | """
76 |
77 | return F.conv1d(inputs[0], inputs[1].view(1,1,-1).to(device), padding=0, stride=2)
78 |
79 |
80 | class HighPassWave(nn.Module):
81 |
82 | """
83 | Performs high-pass filtering by convolving the input signal with a reversed and sign-flipped kernel.
84 | """
85 |
86 | def __init__(self):
87 | super(HighPassWave, self).__init__()
88 |
89 | def initialize_qmfFlip(self, input_shape):
90 |
91 | """
92 | Initializes the sign-flipping tensor used for high-pass filtering.
93 |
94 | Parameters:
95 | - input_shape (tuple): Shape of the input signal.
96 | """
97 |
98 | qmfFlip = torch.tensor([(-1) ** i for i in range(input_shape[0])], dtype=torch.float32)
99 | self.qmfFlip = nn.Parameter(qmfFlip.view(1, 1, -1), requires_grad=False).to(device)
100 |
101 | def forward(self, inputs):
102 |
103 | """
104 | Applies high-pass filtering by convolving the input with a reversed and sign-flipped kernel.
105 |
106 | Parameters:
107 | - inputs (tuple): A tuple containing the input signal and the filter kernel.
108 |
109 | Returns:
110 | - The high-pass filtered signal.
111 | """
112 |
113 | if not hasattr(self, 'qmfFlip'):
114 |
115 | self.initialize_qmfFlip(inputs[1].shape)
116 |
117 | return F.conv1d(inputs[0], torch.mul(torch.flip(inputs[1], [0]).to(device), self.qmfFlip), padding=0, stride=2)
118 |
119 |
120 |
121 | class HardThresholdAssym(nn.Module):
122 |
123 | """
124 | Implements an asymmetrical hard-thresholding function that is learnable and can be applied to input signals.
125 | """
126 |
127 | def __init__(self, init=None, alpha=None, trainBias=True):
128 |
129 | """
130 | Initializes the HardThresholdAssym object.
131 |
132 | Parameters:
133 | - init (float, optional): Initial threshold value.
134 | - alpha (float, optional): Sharpness parameter of the sigmoid function.
135 | - trainBias (bool): Specifies whether the threshold values are learnable.
136 | """
137 |
138 | super(HardThresholdAssym, self).__init__()
139 | self.trainBias = trainBias
140 |
141 | if isinstance(init, float) or isinstance(init, int):
142 | self.init = torch.tensor([init], dtype=torch.float32)
143 | else:
144 | self.init = torch.ones(1, dtype=torch.float32)
145 |
146 | if isinstance(alpha, float) or isinstance(alpha, int):
147 | self.alpha = torch.tensor([alpha], dtype=torch.float32)
148 | else:
149 | self.alpha = torch.ones(1, dtype=torch.float32)
150 |
151 | if torch.cuda.is_available():
152 | self.init = self.init.to(device)
153 | self.alpha = self.alpha.to(device)
154 |
155 | self.thrP = nn.Parameter(self.init, requires_grad=self.trainBias)
156 | self.thrN = nn.Parameter(self.init, requires_grad=self.trainBias)
157 |
158 | self.alpha = nn.Parameter(self.alpha, requires_grad=self.trainBias)
159 |
160 | def forward(self, inputs):
161 |
162 | """
163 | Applies the asymmetric hard thresholding to the input signals.
164 |
165 | Parameters:
166 | - inputs (Tensor): Input signal tensor.
167 |
168 | Returns:
169 | - Thresholded signal tensor.
170 | """
171 |
172 | return inputs * (
173 | torch.sigmoid(self.alpha * (inputs - self.thrP))
174 | + torch.sigmoid(-self.alpha * (inputs + self.thrN))
175 | )
176 |
177 | class SpatialAttentionBlock(nn.Module):
178 |
179 | """
180 |     Implements a linear attention block that applies a spatial attention mechanism over 1D signals.
181 | """
182 |
183 | def __init__(self, in_features, normalize_attn=True):
184 |
185 | """
186 |         Initializes the SpatialAttentionBlock object.
187 |
188 | Parameters:
189 | - in_features (int): Number of input features.
190 | - normalize_attn (bool): Specifies whether to normalize attention weights using softmax.
191 | """
192 |
193 | super(SpatialAttentionBlock, self).__init__()
194 | self.normalize_attn = normalize_attn
195 | self.op = nn.Conv1d(in_channels=in_features, out_channels=1, kernel_size=1, padding=0, bias=False)
196 | def forward(self, l, g):
197 |
198 | """
199 | Applies attention to the input features.
200 |
201 | Parameters:
202 | - l (Tensor): Local input features.
203 | - g (Tensor): Global context.
204 |
205 | Returns:
206 | - A tuple of attention scores and attended feature map.
207 | """
208 |
209 | N, C, W = l.size()
210 | c = self.op(l+g) # batch_sizex1xW
211 | if self.normalize_attn:
212 | a = F.softmax(c.view(N,1,-1), dim=2).view(N,1,W)
213 | else:
214 | a = torch.sigmoid(c)
215 | g = torch.mul(a.expand_as(l), l)
216 | if self.normalize_attn:
217 | g = g.view(N,C,-1).sum(dim=2) # batch_sizexC
218 | else:
219 |             g = F.adaptive_avg_pool1d(g, 1).view(N,C) # 1D adaptive pooling takes an int output size
220 | return c.view(N,1,W), g
221 |
222 |
223 | class TemporalAttn(nn.Module):
224 |
225 | """
226 | Implements temporal attention mechanism over sequences of hidden states.
227 | """
228 |
229 | def __init__(self, hidden_size):
230 |
231 | """
232 | Initializes the TemporalAttn object.
233 |
234 | Parameters:
235 | - hidden_size (int): Dimensionality of the hidden states.
236 | """
237 |
238 | super(TemporalAttn, self).__init__()
239 | self.hidden_size = hidden_size
240 | self.fc1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
241 | self.fc2 = nn.Linear(self.hidden_size*2, self.hidden_size, bias=False)
242 |
243 | def forward(self, hidden_states):
244 |
245 | """
246 | Applies temporal attention to the sequence of hidden states.
247 |
248 | Parameters:
249 | - hidden_states (Tensor): A tensor of shape (batch_size, time_steps, hidden_size) containing the hidden states.
250 |
251 | Returns:
252 | - The attention vector and attention weights.
253 | """
254 |
255 | # (batch_size, time_steps, hidden_size)
256 | score_first_part = self.fc1(hidden_states)
257 | # (batch_size, hidden_size)
258 | h_t = hidden_states[:,-1,:]
259 | # (batch_size, time_steps)
260 | score = torch.bmm(score_first_part, h_t.unsqueeze(2)).squeeze(2)
261 | attention_weights = F.softmax(score, dim=1)
262 | # (batch_size, hidden_size)
263 | context_vector = torch.bmm(hidden_states.permute(0,2,1), attention_weights.unsqueeze(2)).squeeze(2)
264 | # (batch_size, hidden_size*2)
265 | pre_activation = torch.cat((context_vector, h_t), dim=1)
266 | # (batch_size, hidden_size)
267 | attention_vector = self.fc2(pre_activation)
268 | attention_vector = torch.tanh(attention_vector)
269 |
270 | return attention_vector, attention_weights
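271 |
272 |
273 | if __name__ == "__main__":
274 |     # Illustrative smoke test only (not part of the published pipeline): it sketches how one
275 |     # analysis level of the learnable wavelet decomposition is assembled from the layers above.
276 |     # Like the rest of this module it assumes a CUDA device, since `device` is hard-coded to
277 |     # cuda above; the batch size, signal length and filter length are arbitrary choices.
278 |     x = torch.randn(4, 1, 16000).to(device)                 # (batch, channels, samples) dummy waveform
279 |     g = Kernel(kernelInit=20, trainKern=True).to(device)    # learnable low-pass filter of length 20
280 |     low, high = LowPassWave(), HighPassWave()
281 |     thresh = HardThresholdAssym(init=1.0, alpha=10, trainBias=True).to(device)
282 |     kern = g(None)                                          # Kernel.forward returns the filter weights
283 |     approx = low([x, kern])                                 # low-pass branch, stride 2
284 |     detail = thresh(high([x, kern]))                        # high-pass branch + learnable denoising
285 |     print(approx.shape, detail.shape)                       # both roughly half the input length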
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Title: Speech Emotion Recognition Model Definition (SigWavNet)
6 | Author: Alaa Nfissi
7 | Date: March 31, 2024
8 | Description: Defines the architecture of the speech emotion recognition model (SigWavNet), incorporating
9 | custom layers, attention mechanisms, and the overall neural network structure.
10 | """
11 |
12 |
13 | from custom_layers import *
14 |
15 | class CNN1DSABiGRUTA(nn.Module):
16 |
17 | """
18 | This class implements a Convolutional Neural Network (CNN) with 1D dilated convolutions,
19 | Spatial Attention, Bidirectional Gated Recurrent Units (Bi-GRU), and Temporal Attention layers.
20 |
21 | Parameters:
22 | - n_input: The number of input channels.
23 | - n_channel: The number of output channels for the convolutional layers.
24 | - hidden_dim: The hidden dimension size for the GRU layers.
25 | - n_layers: The number of layers in the GRU.
26 | - normalize_attn: Boolean indicating whether to normalize attention weights.
27 | """
28 |
29 | def __init__(self, n_input, n_channel, hidden_dim, n_layers, normalize_attn=True):
30 | super().__init__()
31 | self.n_channel = n_channel
32 | self.n_input = n_input
33 | self.normalize_attn = normalize_attn
34 | self.hidden_dim = hidden_dim
35 | self.n_layers = n_layers
36 |
37 |
38 | ################################################ 1D CNN ################################################################
39 |
40 | # Define the first convolutional layer with dilation
41 | self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=7, stride=4, dilation=3)
42 | self.in1 = nn.InstanceNorm1d(n_channel)
43 | self.relu1 = nn.LeakyReLU()
44 |
45 | # Define the second convolutional layer with dilation
46 | self.conv1_1 = nn.Conv1d(n_channel, 2*n_channel, kernel_size=5, stride=4, dilation=2)
47 | self.in1_1 = nn.InstanceNorm1d(2*n_channel)
48 | self.relu1_1 = nn.LeakyReLU()
49 |
50 |
51 | ############################################ Spatial Attention #########################################################
52 |
53 | # Define spatial attention components
54 | self.dense1 = nn.Conv1d(in_channels=2*n_channel, out_channels=2*n_channel, kernel_size=7, padding=3, bias=True)
55 | self.in1_2 = nn.InstanceNorm1d(2*n_channel)
56 | self.relu1_2 = nn.LeakyReLU()
57 |
58 | # Define the spatial attention layer
59 | self.spatialAttn= SpatialAttentionBlock(in_features=2*n_channel, normalize_attn=self.normalize_attn)
60 |
61 | self.reluAtt1 = nn.LeakyReLU()
62 |
63 |
64 | #################################################### Bi-GRU #############################################################
65 |
66 | # Define the Bi-GRU layer
67 | self.gru1 = nn.GRU(2*n_channel, self.hidden_dim, self.n_layers, batch_first=True, bidirectional=True, dropout=0)
68 |
69 |
70 | ############################################ Temporal Attention #########################################################
71 |
72 | # Define the temporal attention layer
73 | self.tempAttn = TemporalAttn(hidden_size=2*hidden_dim)
74 |
75 |
76 | def forward(self, x, h):
77 |
78 | """
79 | Forward pass of the CNN1DSABiGRUTA model.
80 |
81 | Parameters:
82 | - x: The input tensor.
83 | - h: The initial hidden state for the GRU layer.
84 |
85 | Returns:
86 | - The output tensor after passing through the CNN, spatial attention, Bi-GRU, and temporal attention layers.
87 | - The updated hidden state.
88 | """
89 |
90 | x = self.conv1(x)
91 | x = self.relu1(self.in1(x))
92 |
93 |
94 | x = self.conv1_1(x)
95 | x = self.relu1_1(self.in1_1(x))
96 |
97 | g1 = self.dense1(x)
98 | g1 = self.relu1_2(self.in1_2(g1))
99 |
100 | c, g_attended_1 = self.spatialAttn(x, g1)
101 |
102 | x = self.reluAtt1(x + g_attended_1.unsqueeze(2)) # Residual connection with attended global features
103 |
104 | x = x.permute(0, 2, 1)
105 |
106 | x, h = self.gru1(x, h)
107 |
108 | x, weights = self.tempAttn(x)
109 |
110 | x = x.unsqueeze(1)
111 |
112 | return x, h
113 |
114 | class ChannelWeighting(nn.Module):
115 |
116 | """
117 | This layer applies a learnable weighting to the channels of its input tensor.
118 |
119 | Parameters:
120 | - num_channels: The number of channels in the input tensor.
121 | """
122 |
123 | def __init__(self, num_channels):
124 | super(ChannelWeighting, self).__init__()
125 |
126 | # Initialize weights for each channel. These weights are learnable parameters.
127 | self.weights = nn.Parameter(torch.ones(num_channels))
128 |
129 | def forward(self, x):
130 |
131 | """
132 | Forward pass of the ChannelWeighting layer.
133 |
134 | Parameters:
135 | - x: Input tensor of shape (batch_size, num_channels, length).
136 |
137 | Returns:
138 | - The input tensor with each channel weighted by the learned weights.
139 | """
140 |
141 | # Apply weights to each channel. The weights are broadcasted to match the input tensor shape.
142 | return x * self.weights.view(1, -1, 1)
143 |
144 | class SigWavNet(nn.Module):
145 |
146 | """
147 | SigWavNet model combines a Learnable Fast Discrete Wavelet Transform (LFDWT) with a 1D CNN,
148 | spatial attention, Bi-GRU, and temporal attention layers for speech emotion recognition (SER).
149 |
150 | Parameters:
151 | - n_input: Number of input channels.
152 | - hidden_dim: Dimension of hidden layers in GRU.
153 | - n_layers: Number of layers in the GRU.
154 | - n_output: Number of output classes.
155 | - stride: Stride size for convolution operations.
156 | - n_channel: Number of channels in convolution layers.
157 | - inputSize: Size of the input signal.
158 | - kernelInit: Initial value or method for kernel initialization.
159 | - kernTrainable: Indicates if kernels are trainable.
160 | - level: Number of decomposition levels in wavelet transform.
161 | - kernelsConstraint: Type of constraint on kernels ('CQF', 'PerLayer', or 'PerFilter').
162 | - initHT: Initial value for hard thresholding.
163 | - trainHT: Indicates if hard thresholding is trainable.
164 | - alpha: Alpha value for LAHT.
165 | """
166 |
167 | def __init__(self, n_input, hidden_dim, n_layers, n_output, stride=2, n_channel=128,
168 |                  inputSize=None, kernelInit=20, kernTrainable=True, level=1, kernelsConstraint='CQF', initHT=1.0, trainHT=True, alpha=10):
169 | super().__init__()
170 | self.n_input = n_input
171 | self.n_channel = n_channel
172 | self.n_output = n_output
173 | self.hidden_dim = hidden_dim
174 | self.n_layers = n_layers
175 | self.inputSize = inputSize
176 | self.kernelInit = kernelInit
177 | self.kernTrainable = kernTrainable
178 | self.level = level
179 | self.kernelsConstraint = kernelsConstraint
180 | self.initHT = initHT
181 | self.trainHT = trainHT
182 | self.alpha = alpha
183 |
184 | # Initialization of kernels based on the constraint specified
185 | if self.kernelsConstraint=='CQF':
186 | kr = Kernel(self.kernelInit, trainKern=self.kernTrainable)
187 | self.kernelsG_ = nn.ModuleList([kr for l in range(self.level)])
188 | self.kernelsG = [kern(None) for kern in self.kernelsG_]
189 | self.kernelsH = self.kernelsG
190 |
191 | elif self.kernelsConstraint=='PerLayer':
192 | self.kernelsG_ = nn.ModuleList([Kernel(self.kernelInit, trainKern=self.kernTrainable) for lev in range(self.level)])
193 | self.kernelsG = [kern(None) for kern in self.kernelsG_]
194 | self.kernelsH = self.kernelsG
195 |
196 | elif self.kernelsConstraint=='PerFilter':
197 | self.kernelsG_ = nn.ModuleList([Kernel(self.kernelInit, trainKern=self.kernTrainable) for lev in range(self.level)])
198 | self.kernelsG = [kern(None) for kern in self.kernelsG_]
199 | self.kernelsH_ = nn.ModuleList([Kernel(self.kernelInit, trainKern=self.kernTrainable) for lev in range(self.level)])
200 | self.kernelsH = [kern(None) for kern in self.kernelsH_]
201 |
202 |
203 | # Wavelet transform layers
204 | self.LowPassWave = nn.ModuleList([LowPassWave() for lev in range(self.level)])
205 | self.HighPassWave = nn.ModuleList([HighPassWave() for lev in range(self.level)])
206 |
207 | self.HardThresholdAssymH = nn.ModuleList([HardThresholdAssym(init=self.initHT,trainBias=self.trainHT, alpha=self.alpha) for lev in range(self.level)])
208 |
209 | #####################################################################################################
210 |
211 | # CNN, spatial attention, and Bi-GRU layers for each wavelet transform level and the low-frequency component
212 | self.conv1ds = nn.ModuleList([CNN1DSABiGRUTA(self.n_input, self.n_channel, self.hidden_dim, self.n_layers, True) for i in range(self.level+1)])
213 |
214 |
215 | #####################################################################################################
216 |
217 | # Channel weighting to combine the wavelet transform levels and the low-frequency component
218 | self.channel_weighting = ChannelWeighting(self.level+1)
219 |
220 | # Final convolutional layer to produce the output
221 | self.conv1_3 = nn.Conv1d(self.level+1, self.n_output, kernel_size=5, stride=1, dilation=1)
222 | self.in1_3 = nn.InstanceNorm1d(self.n_output)
223 | self.relu1_3 = nn.LeakyReLU()
224 |
225 |
226 | def forward(self, x, h):
227 |
228 | """
229 | Forward pass through the SigWavNet model.
230 |
231 | Parameters:
232 | - x: Input tensor of shape (batch_size, n_input, sequence_length).
233 | - h: Initial hidden states for the GRU layers.
234 |
235 | Returns:
236 | - The output tensor of shape (batch_size, n_output).
237 | - The updated hidden states.
238 | """
239 |
240 | # Decomposition using wavelet transform and processing with CNNs, spatial attention, and Bi-GRU
241 | wav_coef = torch.Tensor().to(device)
242 |
243 | for lev in range(self.level):
244 |
245 | hl = self.HighPassWave[lev]([x, self.kernelsH[lev]])
246 | hl = self.HardThresholdAssymH[lev](hl)
247 | x = self.LowPassWave[lev]([x, self.kernelsG[lev]])
248 | hl, h[lev] = self.conv1ds[lev](hl, h[lev])
249 | wav_coef = torch.cat((wav_coef, hl),1)
250 |
251 | ########################################################
252 |
253 | # Processing the low-frequency component
254 | x, h[self.level] = self.conv1ds[self.level](x, h[self.level])
255 |
256 |
257 | x = torch.cat((x, wav_coef),1)
258 |
259 | ################## Channel weighting ##################
260 |
261 |
262 | # Applying channel weighting
263 | x = self.channel_weighting(x)
264 |
265 | ########################################################
266 |
267 | # Final processing to produce the output
268 | x = self.conv1_3(x)
269 | x = self.relu1_3(self.in1_3(x))
270 |
271 | # Global average pooling
272 | x = x.mean(2)
273 |
274 | output = F.log_softmax(x, dim=1)
275 | return output , h
276 |
277 | def init_hidden(self, batch_size):
278 |
279 | """
280 | Initializes hidden states for each GRU layer.
281 |
282 | Parameters:
283 | - batch_size: The batch size.
284 |
285 | Returns:
286 | - A list of initial hidden states for each Bi-GRU layer.
287 | """
288 |
289 | weight = next(self.parameters()).data
290 | hidden = weight.new(self.n_layers*2, batch_size, self.hidden_dim).zero_().to(device)
291 | return [hidden for i in range(self.level+1)]
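292 |
293 |
294 | if __name__ == "__main__":
295 |     # Illustrative smoke test only (not described in the paper): build a small SigWavNet and run
296 |     # one forward pass on random audio. The hyperparameter values are arbitrary, and a CUDA
297 |     # device is assumed because `device` in custom_layers.py is hard-coded to cuda.
298 |     batch_size, n_classes = 2, 4
299 |     net = SigWavNet(n_input=1, hidden_dim=64, n_layers=3, n_output=n_classes,
300 |                     kernelInit=20, level=3, kernelsConstraint='PerFilter').to(device)
301 |     h = net.init_hidden(batch_size)
302 |     x = torch.randn(batch_size, 1, 16000).to(device)  # roughly one second of 16 kHz audio per sample
303 |     log_probs, h = net(x, h)
304 |     print(log_probs.shape)  # expected: (batch_size, n_classes) log-probabilities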
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
9 |
10 |
11 |
14 | This paper has been accepted for publication in IEEE Transactions on Affective Computing.
15 |
16 |
·
32 | Report Bug
33 | ·
34 | Request Feature
35 |
36 |In the field of human-computer interaction and psychological assessment, speech emotion recognition (SER) plays an important role in deciphering emotional states from speech signals. Despite advancements, challenges persist due to system complexity, feature distinctiveness issues, and noise interference. This paper introduces a new end-to-end (E2E) deep learning multi-resolution framework for SER, addressing these limitations by extracting meaningful representations directly from raw waveform speech signals. By leveraging the properties of the fast discrete wavelet transform (FDWT), including the cascade algorithm, conjugate quadrature filter, and coefficient denoising, our approach introduces a learnable model for both wavelet bases and denoising through deep learning techniques. The framework incorporates an activation function for learnable asymmetric hard thresholding of wavelet coefficients. Our approach exploits the capabilities of wavelets for effective localization in both time and frequency domains. We then combine one-dimensional dilated convolutional neural networks (1D dilated CNN) with a spatial attention layer and bidirectional gated recurrent units (Bi-GRU) with a temporal attention layer to efficiently capture the nuanced spatial and temporal characteristics of emotional features. By handling variable-length speech without segmentation and eliminating the need for pre or post-processing, the proposed model outperformed state-of-the-art methods on IEMOCAP and EMO-DB datasets.
72 |
99 | To ensure consistency and compatibility across our datasets, we first convert all audio signals to a uniform 16 kHz sampling rate and mono-channel format. We then divide each dataset into two primary subsets: 90% for training and validation purposes, and the remaining 10% designated for testing as unseen data. For the training and validation segments, we implement a 10-fold cross-validation method. This partitioning and the allocation within the cross-validation folds leverage stratified random sampling, a method that organizes the dataset into homogeneous strata based on emotional categories. Unlike basic random sampling, this approach guarantees a proportional representation of each class, leading to a more equitable and representative dataset division.
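A minimal sketch of this partitioning step is shown below; it assumes a dataframe with `path`, `source` and `label` columns like the one produced by `Data_exploration/data_exploration.ipynb`. The repository's own implementation (`get_dataset_partitions_pd` and the dataloader helpers in `utils.py`) differs in its details.

```python
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split

df = pd.read_csv("IEMOCAP_dataset.csv")

# Hold out 10% as unseen test data, stratified by emotion label
train_val_df, test_df = train_test_split(
    df, test_size=0.10, stratify=df["label"], random_state=42)

# 10-fold stratified cross-validation on the remaining 90%
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(skf.split(train_val_df, train_val_df["label"])):
    train_fold = train_val_df.iloc[train_idx]
    val_fold = train_val_df.iloc[val_idx]
    # build the datasets/dataloaders for this fold here
```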
100 |
101 |
102 | In the quest to identify optimal hyperparameters for our model, we utilize a grid search strategy. Hyperparameter tuning can be approached in several ways, including the use of scheduling algorithms. These schedulers can efficiently manage trials by early termination of less promising ones, as well as pausing, duplicating, or modifying the hyperparameters of ongoing trials. For its effectiveness and performance, we have selected the Asynchronous Successive Halving Algorithm (ASHA) as our optimization technique.
103 | The data preprocessing used in this study is provided in the `Data_exploration` folder.
104 |
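As a self-contained illustration of how ASHA prunes unpromising trials, the toy Ray Tune run below mirrors the scheduler settings used in `main.py` (`metric="loss"`, `mode="min"`, `grace_period=1`, `reduction_factor=2`); the objective function is a stand-in, not the SigWavNet training loop.

```python
from ray import tune
from ray.tune.schedulers import ASHAScheduler

def trainable(config):
    # Dummy objective standing in for one SigWavNet training run
    for epoch in range(10):
        tune.report(loss=1.0 / (epoch + 1) + config["lr"])

scheduler = ASHAScheduler(metric="loss", mode="min", max_t=10,
                          grace_period=1, reduction_factor=2)

analysis = tune.run(trainable,
                    config={"lr": tune.loguniform(1e-4, 1e-1)},
                    num_samples=8,
                    scheduler=scheduler)
print(analysis.get_best_trial("loss", "min", "last").config)
```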
105 |
106 | ### Getting the code
107 |
108 | You can download a copy of all the files in this repository by cloning the
109 | [git](https://git-scm.com/) repository:
110 |
111 |     git clone https://github.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition.git
112 |
113 | or [download a zip archive](https://github.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition/archive/refs/heads/main.zip).
114 |
115 | ### Dependencies
116 |
117 |
118 | You'll need a working Python environment to run the code.
119 | The recommended way to set up your environment is through the
120 | [Anaconda Python distribution](https://www.anaconda.com/download/) which
121 | provides the `conda` package manager.
122 | Anaconda can be installed in your user directory and does not interfere with
123 | the system Python installation.
124 | The required dependencies are specified in the file `requirements.txt`.
125 | We use `conda` virtual environments to manage the project dependencies in
126 | isolation.
127 | Thus, you can install our dependencies without causing conflicts with your
128 | setup (even with different Python versions).
129 | Run the following command to create a separate `ser-env` environment:
130 |
131 | ```sh
132 | conda create --name ser-env
133 | ```
134 |
135 | Activate the environment; this enables it for your current terminal session, and any subsequent commands will use software that is installed in the environment:
136 |
137 | ```sh
138 | conda activate ser-env
139 | ```
140 |
141 | Use pip to install packages into the Anaconda environment:
142 |
143 | ```sh
144 | conda install pip
145 | ```
146 |
147 | Install all required dependencies in it:
148 |
149 | ```sh
150 | pip install -r requirements.txt
151 | ```
152 |
153 |
154 |
155 | ### Reproducing the results
156 |
157 |
158 |
159 | 1. First, you need to download the IEMOCAP and EMO-DB datasets:
160 |    * [IEMOCAP official website](https://sail.usc.edu/iemocap/)
161 |    * [EMO-DB official website](http://www.emodb.bilderbar.info/download/)
162 |
163 | 2. To be able to explore the data, you need to execute the Jupyter Notebook that prepares the `csv` files needed for the experiments.
164 | To do this, you must first start the notebook server by going into the
165 | repository top level and running:
166 | ```sh
167 | jupyter notebook
168 | ```
169 | This will start the server and open your default web browser to the Jupyter
170 | interface. On the page, go into the `Data_exploration` folder and select the
171 | `data_exploration.ipynb` notebook to view/run. Make sure to specify the correct dataset paths on your machine as described in the notebook.
172 | The notebook is divided into cells (some have text while others have code).
173 | Each cell can be executed using `Shift + Enter`.
174 | Executing text cells does nothing, while executing code cells runs the code and produces its output.
175 | To execute the whole notebook, run all cells in order.
176 |
177 | 3. After generating the needed `csv` files `IEMOCAP_dataset.csv` and `EMO_DB_dataset.csv`, go to the terminal where the `ser-env` environment was
178 | activated, go to the project folder, and run the Python script `main.py` as follows:
179 |
180 | ```sh
181 | python main.py
182 | ```
183 | _You can do the same thing for the EMO-DB dataset by changing the dataset csv file to `EMO_DB_dataset.csv` (a sketch for building this csv is given below)._
184 |
185 |
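The notebook above only builds `IEMOCAP_dataset.csv`. A minimal sketch for producing `EMO_DB_dataset.csv` in the same `path`/`source`/`label` format is given below; it relies on the standard EMO-DB filename convention (the sixth character encodes the emotion), while the folder path, the `source` string and the label names are assumptions you may need to adapt to your setup.

```python
import os
import pandas as pd

emodb_wav_dir = "EMO-DB WAV FOLDER PATH"  # path to the wav/ folder of EMO-DB

# Standard EMO-DB emotion codes (German initials) mapped to label names (assumed here)
code_to_label = {'W': 'ang', 'F': 'hap', 'N': 'neu', 'T': 'sad',
                 'A': 'fear', 'E': 'disgust', 'L': 'boredom'}

rows = []
for fname in sorted(os.listdir(emodb_wav_dir)):
    if fname.endswith('.wav'):
        label = code_to_label[fname[5]]  # 6th character of the filename encodes the emotion
        rows.append({'path': os.path.join(emodb_wav_dir, fname),
                     'source': 'EMO-DB', 'label': label})

pd.DataFrame(rows).to_csv('EMO_DB_dataset.csv', index=False)
```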
186 |
187 |
188 |
189 | ## Results
190 |
191 | ### On IEMOCAP dataset
192 |
The trials showcase the proficiency of the SigWavNet model in recognizing diverse emotional expressions from the IEMOCAP dataset. This model achieves notable accuracy in distinguishing between various emotions, as indicated by its precision, recall, and F1-score across different emotional categories. Specifically, SigWavNet performs exceptionally well in identifying 'Neutral' emotions, achieving a high precision rate of 97% and a recall rate of 93% (refer to the paper). This underscores the model's strength in accurately pinpointing this particular emotional state. The confusion matrix figure below shows the class-wise test results on IEMOCAP.
194 |
195 |
196 | SigWavNet confusion matrix on IEMOCAP |
197 | :-----------------------------------------------------------------:|
198 |  |
199 |
200 |
201 | ### On EMO-DB dataset
202 |
The evaluation of SigWavNet on the EMO-DB dataset provides a comprehensive analysis of its ability to distinguish between various emotional states, as demonstrated by its commendable precision, recall, and F1-score metrics for different emotions. Particularly notable is the model's performance on 'Anger', where it achieves an exceptional precision rate of 100%, reflecting its reliability in predicting this specific emotion. Alongside a recall rate of 92.3%, SigWavNet effectively identifies the majority of 'Anger' instances, leading to a balanced F1-score of 96% (refer to the paper). The confusion matrix figure below shows the class-wise test results on EMO-DB.
204 |
205 |
206 | SigWavNet confusion matrix on EMO-DB |
207 | :-----------------------------------------------------------------:|
208 |  |
209 |
210 |
211 |
212 |
213 |
214 | _For more detailed experiments and results you can read the paper._
215 |
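Confusion matrices such as the two above can be regenerated from the predictions returned by `test()` in `main.py`, using the same libraries the project already depends on (`scikit-learn`, `seaborn`). The snippet below is only a sketch: the dummy `y_true`/`y_pred` arrays stand in for the real model outputs, and it assumes labels are integer-encoded in the sorted class order used by `load_data()`.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn
from sklearn.metrics import classification_report, confusion_matrix

# In the real pipeline these come from test(); dummy values keep the example self-contained.
y_true = [0, 1, 2, 3, 0, 1, 2, 3]
y_pred = [0, 1, 2, 2, 0, 3, 2, 3]

labels = ['ang', 'hap', 'neu', 'sad']  # classes kept by load_data()
cm = confusion_matrix(y_true, y_pred, normalize='true')
sn.heatmap(pd.DataFrame(cm, index=labels, columns=labels),
           annot=True, fmt='.2f', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('my_confusion_matrix.png', bbox_inches='tight')
print(classification_report(y_true, y_pred, target_names=labels))
```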
216 |
217 |
218 |
219 | ## Contributing
220 |
221 | Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
222 |
223 | If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement".
224 | Don't forget to give the project a star! Thanks again!
225 |
226 | 1. Fork the Project
227 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
228 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
229 | 4. Push to the Branch (`git push origin feature/AmazingFeature`)
230 | 5. Open a Pull Request
231 |
232 |
233 |
234 |
235 |
236 |
237 | ## License
238 |
239 | All source code is made available under a BSD 3-clause license. You can freely
240 | use and modify the code, without warranty, so long as you provide attribution
241 | to the authors. See `LICENSE` for the full license text.
242 |
243 |
244 |
245 |
246 |
247 |
248 | ## Contact
249 |
250 | Alaa Nfissi - [@LinkedIn](https://www.linkedin.com/in/alaa-nfissi/) - alaa.nfissi@mail.concordia.ca
251 |
252 | Github Link: [https://github.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition](https://github.com/alaaNfissi/SigWavNet-Learning-Multiresolution-Signal-Wavelet-Network-for-Speech-Emotion-Recognition)
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 | [contributors-shield]: https://img.shields.io/github/contributors/othneildrew/Best-README-Template.svg?style=for-the-badge
262 | [contributors-url]: https://github.com/othneildrew/Best-README-Template/graphs/contributors
263 | [forks-shield]: https://img.shields.io/github/forks/othneildrew/Best-README-Template.svg?style=for-the-badge
264 | [forks-url]: https://github.com/othneildrew/Best-README-Template/network/members
265 | [stars-shield]: https://img.shields.io/github/stars/othneildrew/Best-README-Template.svg?style=for-the-badge
266 | [stars-url]: https://github.com/othneildrew/Best-README-Template/stargazers
267 | [issues-shield]: https://img.shields.io/github/issues/othneildrew/Best-README-Template.svg?style=for-the-badge
268 | [issues-url]: https://github.com/othneildrew/Best-README-Template/issues
269 | [license-shield]: https://img.shields.io/github/license/othneildrew/Best-README-Template.svg?style=for-the-badge
270 | [license-url]: https://github.com/othneildrew/Best-README-Template/blob/master/LICENSE.txt
271 | [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=for-the-badge&logo=linkedin&colorB=555
272 | [linkedin-url]: https://linkedin.com/in/othneildrew
273 | [model-architecture]: figures/Aggregated_SigWavNet_V.png
274 |
275 |
276 | [anaconda.com]: https://anaconda.org/conda-forge/mlconjug/badges/version.svg
277 | [anaconda-url]: https://anaconda.org/conda-forge/mlconjug
278 |
279 | [React.js]: https://img.shields.io/badge/React-20232A?style=for-the-badge&logo=react&logoColor=61DAFB
280 | [React-url]: https://reactjs.org/
281 | [Vue.js]: https://img.shields.io/badge/Vue.js-35495E?style=for-the-badge&logo=vuedotjs&logoColor=4FC08D
282 | [Vue-url]: https://vuejs.org/
283 | [Angular.io]: https://img.shields.io/badge/Angular-DD0031?style=for-the-badge&logo=angular&logoColor=white
284 | [Angular-url]: https://angular.io/
285 | [Svelte.dev]: https://img.shields.io/badge/Svelte-4A4A55?style=for-the-badge&logo=svelte&logoColor=FF3E00
286 | [Svelte-url]: https://svelte.dev/
287 | [Laravel.com]: https://img.shields.io/badge/Laravel-FF2D20?style=for-the-badge&logo=laravel&logoColor=white
288 | [Laravel-url]: https://laravel.com
289 | [Bootstrap.com]: https://img.shields.io/badge/Bootstrap-563D7C?style=for-the-badge&logo=bootstrap&logoColor=white
290 | [Bootstrap-url]: https://getbootstrap.com
291 | [JQuery.com]: https://img.shields.io/badge/jQuery-0769AD?style=for-the-badge&logo=jquery&logoColor=white
292 | [JQuery-url]: https://jquery.com
293 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Title: Main Execution Script for Speech Emotion Recognition (SigWavNet)
6 | Author: Alaa Nfissi
7 | Date: March 31, 2024
8 | Description: The main script to train, validate, and test the speech emotion recognition model (SigWavNet).
9 | It includes the configuration and execution of experiments, model evaluation, and result reporting.
10 | """
11 |
12 | from utils import * # Import utility functions, classes, and global variables
13 | from model import * # Import model definition
14 | from custom_layers import * # Import custom layer definitions
15 |
16 |
17 | # Check and print CUDA availability and details
18 | print(f"Is cuda available ? => {torch.cuda.is_available()}")
19 | print(f"How many devices available ? => {torch.cuda.device_count()}")
20 | print(f"Current device => {torch.cuda.current_device()}")
21 | d = 0 # Device ID to use
22 | print(f"Picked device => {d}")
23 | print(f"Device's name => {torch.cuda.get_device_name(d)}")
24 | device = torch.device(f"cuda:{d}") # Set the device for training
25 |
26 |
27 | # Setup for data loaders based on CUDA availability
28 | if device.type == "cuda":
29 |     num_workers = 10 # Use more workers for faster data loading on CUDA
30 |     pin_memory = True # Helps with faster transfer to CUDA device
31 | else:
32 |     num_workers = 0 # CPU mode requires fewer workers
33 |     pin_memory = False # No pinning needed in CPU mode
34 |
35 |
36 | # Define paths for experiments and model checkpoints
37 | iemocap_experiments_folder = "experiments"
38 | checkpoint_dir = "models"
39 |
40 |
41 | def train_SigWavNet(config, checkpoint_dir=None, data=None, max_num_epochs=None):
42 |
43 |     """
44 |     Trains the SigWavNet model with given configuration, data, and number of epochs.
45 |
46 |     Parameters:
47 |     - config: A dictionary with configuration parameters for the model and training.
48 |     - checkpoint_dir: Directory where checkpoints are saved.
49 |     - data: The dataset to use for training and validation.
50 |     - max_num_epochs: The maximum number of epochs to train for.
51 |
52 |     Returns:
53 |     - None
54 |     """
55 |
56 |     epoch_count = max_num_epochs
57 |     log_interval = 20 # Interval for logging training progress
58 |
59 |     # Prepare class weights for focal loss
60 |     a, class_counts = np.unique(data['label'], return_counts=True)
61 |     num_classes = len(class_counts)
62 |     total_samples = len(data['label'])
63 |
64 |     class_weights = []
65 |     for count in class_counts:
66 |         weight = 1 / (count / total_samples)
67 |         class_weights.append(weight)
68 |     print(class_weights)
69 |
70 |     # Initialize criterion with focal loss
71 |     criterion = FocalLoss(alpha=torch.FloatTensor(class_weights), gamma=2)
72 |
73 |
74 |     # Initialize SigWavNet model with given config
75 |     model_SigWavNet = SigWavNet(n_input=config['n_input'], hidden_dim=config['hidden_dim'], n_layers=config['n_layers'] , n_output=config['n_output'], inputSize=None, kernelInit=config['kernelInit'], kernTrainable=config['kernTrainable'], level=config['level'], kernelsConstraint=config['mode'], initHT=config['initHT'], trainHT=config['trainHT'], alpha=config['alpha'])
76 |
77 |     device = "cpu"
78 |     if torch.cuda.is_available():
79 |         device = "cuda:0" # Use first CUDA device
80 |         pin_memory = True
81 |         if torch.cuda.device_count() > 1:
82 |             model_SigWavNet = nn.DataParallel(model_SigWavNet) # Use DataParallel for multi-GPU
83 |     model_SigWavNet.to(device) # Move model to the chosen device
84 |
85 |     # Initialize optimizer and learning rate scheduler
86 |     optimiser = optim.Adam(model_SigWavNet.parameters(), lr=0.00001, weight_decay=0.00001)
87 |     scheduler = optim.lr_scheduler.StepLR(optimiser, step_size=20, gamma=0.1)
88 |
89 |     # Load checkpoint if directory is provided
90 |     if checkpoint_dir:
91 |         model_state, optimiser_state = torch.load(
92 |             os.path.join(checkpoint_dir, "checkpoint"))
93 |         model_SigWavNet.load_state_dict(model_state)
94 |         optimiser.load_state_dict(optimiser_state)
95 |
96 |     # Prepare dataloaders for cross-validation
97 |     dataloaders = get_dataloaders(data, batch_size=config["batch_size"], num_splits=config["num_splits"])
98 |
99 |     # Initialize lists to track training and validation metrics
100 |     losses_train = []
101 |     losses_validation = []
102 |     accuracy_train = []
103 |     accuracy_validation = []
104 |     total_train_acc = 0.0
105 |     total_val_acc = 0.0
106 |
107 |     pbar_update_1 = 1 # Progress bar update size for folds
108 |
109 |     with tqdm(total=config["num_splits"]) as pbar_1:
110 |         for fold, (train_loader, val_loader) in enumerate(dataloaders, 0):
111 |             pbar_update = 1 / (len(train_loader) + len(val_loader)) # Progress bar update size for batches
112 |             print(f'Fold {fold+1}/{config["num_splits"]}')
113 |             optimiser = optim.Adam(model_SigWavNet.parameters(), lr=0.00001, weight_decay=0.00001)
114 |             scheduler = optim.lr_scheduler.StepLR(optimiser, step_size=20, gamma=0.1)
115 |             with tqdm(total=epoch_count) as pbar:
116 |
117 |                 for epoch in range(1, epoch_count + 1):
118 |
119 |                     model_SigWavNet.train() # Set model to training mode
120 |                     right = 0 # Track number of correct predictions
121 |
122 |                     h = model_SigWavNet.init_hidden(config["batch_size"]) # Initialize hidden states
123 |
124 |
125 |                     for batch_index, (data, target) in enumerate(train_loader):
126 |
127 |                         data = data.to(device)
128 |                         target = target.to(device)
129 |
130 |                         h = [i.data for i in h] # Detach hidden states
131 |
132 |                         output, h = model_SigWavNet(data, h) # Forward pass
133 |
134 |                         pred = get_probable_idx(output) # Get predicted classes
135 |                         right += nr_of_right(pred, target) # Update correct predictions count
136 |
137 |                         loss = criterion(output.squeeze(), target) # Compute loss
138 |
139 |                         optimiser.zero_grad() # Zero gradients
140 |                         loss.backward() # Backpropagation
141 |                         optimiser.step() # Update weights
142 |
143 |                         # Log training progress
144 |                         if batch_index % log_interval == 0:
145 |                             print(f"Train Epoch: {epoch} [{batch_index * len(data)}/{len(train_loader.dataset)} ({100. * batch_index / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}\tAccuracy: {right}/{len(train_loader.dataset)} ({100. * right / len(train_loader.dataset):.0f}%)")
146 |
147 |                         pbar.update(pbar_update) # Update progress bar
148 |
149 |                         losses_train.append(loss.item()) # Track training loss
150 |
151 |                         free_memory([data, target, output]+h) # Free up memory
152 |                     ###################################################################################################
153 |
154 |                     # Validation loop
155 |                     model_SigWavNet.eval() # Set model to evaluation mode
156 |                     right = 0 # Reset correct predictions count
157 |                     val_loss = 0.0 # Reset validation loss
158 |                     val_steps = 0 # Reset validation steps count
159 |                     h = model_SigWavNet.init_hidden(config["batch_size"]) # Reinitialize hidden states
160 |
161 |                     for batch_index, (data, target) in enumerate(val_loader):
162 |
163 |                         data = data.to(device)
164 |                         target = target.to(device)
165 |
166 |                         h = [i.data for i in h] # Detach hidden states
167 |
168 |                         output, h = model_SigWavNet(data, h) # Forward pass
169 |
170 |                         pred = get_probable_idx(output) # Get predicted classes
171 |
172 |                         right += nr_of_right(pred, target) # Update correct predictions count
173 |
174 |                         loss = criterion(output.squeeze(), target) # Compute loss
175 |
176 |                         val_loss += loss.item() # Accumulate validation loss
177 |                         val_steps += 1 # Increment validation steps
178 |
179 |                         pbar.update(pbar_update) # Update progress bar
180 |
181 |                         free_memory([data, target, output]+h) # Free up memory
182 |
183 |
184 |                     # Log validation progress
185 |                     print(f"\nValidation Epoch: {epoch} \tLoss: {loss.item():.6f}\tAccuracy: {right}/{len(val_loader.dataset)} ({100. * right / len(val_loader.dataset):.0f}%)\n")
186 |
187 |                     # Calculate validation accuracy
188 |                     acc = 100. * right / len(val_loader.dataset)
189 |                     accuracy_validation.append(acc) # Track validation accuracy
190 |
191 |                     losses_validation.append(loss.item()) # Track validation loss
192 |
193 |                     scheduler.step() # Update learning rate
194 |
195 |                     # Save checkpoint
196 |                     with tune.checkpoint_dir(epoch) as checkpoint_dir:
197 |                         path = os.path.join(checkpoint_dir, "checkpoint") # Define checkpoint path
198 |                         torch.save((model_SigWavNet.state_dict(), optimiser.state_dict()), path) # Save model and optimizer states
199 |
200 |                     tune.report(loss=(val_loss / val_steps), accuracy=acc) # Report metrics to Ray Tune
201 |             pbar_1.update(pbar_update_1) # Update outer progress bar
202 |     print("Finished Training !") # Indicate training completion
203 |
204 |
205 |
206 | def test(model, batch_size, data, test_ds):
207 |
208 |     """
209 |     Tests the given model on a test dataset.
210 |
211 |     Parameters:
212 |     - model: The trained model to be tested.
213 |     - batch_size: The batch size to use for testing.
214 |     - data: The training dataframe, used to compute normalization statistics; test_ds: the held-out test dataframe.
215 |
216 |     Returns:
217 |     - Test set accuracy, predicted labels, and true labels.
218 |     """
219 |
220 |     model.eval() # Set model to evaluation mode
221 |     right = 0 # Reset correct predictions count
222 |
223 |     # Compute mean and standard deviation for normalization
224 |     train_mean, train_std = compute_precise_mean_std(data['path'])
225 |     test_transform = MyTransformPipeline(train_mean=train_mean, train_std=train_std) # Initialize transformation pipeline
226 |     test_set = MyDataset(test_ds['path'], test_ds['label'], test_transform) # Create test dataset
227 |
228 |     # Setup device for testing
229 |     device = "cpu"
230 |     if torch.cuda.is_available():
231 |         device = "cuda:0" # Use first CUDA device
232 |         pin_memory = True # Enable pinning for faster transfers to CUDA device
233 |
234 |     # Initialize test loader
235 |     test_loader = torch.utils.data.DataLoader(
236 |         test_set,
237 |         batch_size=batch_size,
238 |         shuffle=False,
239 |         drop_last=True,
240 |         collate_fn=collate_fn,
241 |         num_workers=num_workers,
242 |         pin_memory=pin_memory,
243 |     )
244 |
245 |     h = model.init_hidden(batch_size) # Initialize hidden states
246 |
247 |     # Lists to store true and predicted labels
248 |     y_true = []
249 |     y_pred = []
250 |     with torch.no_grad(): # Disable gradient computation
251 |         for data, target in test_loader:
252 |
253 |             data = data.to(device)
254 |             target = target.to(device)
255 |
256 |             targets = target.data.cpu().numpy() # Get true labels
257 |             y_true.extend(targets)
258 |
259 |             h = [i.data for i in h] # Detach hidden states
260 |
261 |             output, h = model(data, h) # Forward pass
262 |
263 |
264 |             pred = get_probable_idx(output) # Get predicted classes
265 |             right += nr_of_right(pred, target) # Update correct predictions count
266 |
267 |             output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy() # Get predicted labels
268 |             y_pred.extend(output)
269 |
270 |     # Print test set accuracy
271 |     print(f"\nTest set accuracy: {right}/{len(test_loader.dataset)} ({100. * right / len(test_loader.dataset):.0f}%)\n")
272 |
273 |     return (100. * right / len(test_loader.dataset)), y_pred, y_true # Return accuracy, predicted labels, and true labels
274 |
275 |
276 |
277 |
278 | def main(num_samples=10, max_num_epochs=10, gpus_per_trial=1):
279 |
280 |     """
281 |     Main function to setup and run the model training and testing.
282 |
283 |     Parameters:
284 |     - num_samples: Number of hyperparameter samples to try.
285 |     - max_num_epochs: Maximum number of epochs to train.
286 |     - gpus_per_trial: Number of GPUs to use per trial.
287 |
288 |     Returns:
289 |     - None
290 |     """
291 |
292 |     # Load data and set wavelet
293 |     data, emotionclasses = load_data('IEMOCAP_dataset.csv')
294 |     wt = pywt.Wavelet('db10') # Define wavelet
295 |
296 |     # Print model details
297 |     print('SigWavNet model')
298 |     print('Model n_input', 1)
299 |     print('Model n_output', len(emotionclasses))
300 |
301 |     # Define hyperparameter configuration
302 |     config = {
303 |         "n_input": tune.choice([1]),
304 |         "hidden_dim": tune.choice([32, 64, 128]),
305 |         "n_layers": tune.choice([3, 6, 9]),
306 |         "n_output": tune.choice([len(emotionclasses)]),
307 |         "weight_decay": tune.loguniform(1e-6, 1e-2),
308 |         'level': tune.choice([4, 5, 6, 7, 8, 9, 10]),
309 |         'trainHT': tune.choice([True, False]),
310 |         'initHT': tune.choice([0.5, 0.6, 0.7, 0.8, 0.9, 1]),
311 |         'kernTrainable': tune.choice([True, False]),
312 |         'kernelInit': np.array(wt.filter_bank[0]),
313 |         'alpha': tune.choice([10, 11, 12, 13, 14, 15]),
314 |         'mode': tune.choice(['CQF', 'PerLayer', 'PerFilter']),
315 |         "lr": tune.loguniform(1e-4, 1e-1),
316 |         "batch_size": tune.grid_search([2, 4, 8, 16, 32]),
317 |         "num_splits": tune.choice([10])
318 |     }
319 |
320 |     # Partition data for cross-validation
321 |     data, test_ds, val_ds = get_dataset_partitions_pd(data,train_split=0.9, val_split=0, test_split=0.1, target_variable='label', data_source='source')
322 |
323 |     # Set up Ray Tune scheduler and reporter
324 |     scheduler = ASHAScheduler(
325 |         metric="loss",
326 |         mode="min",
327 |         max_t=max_num_epochs,
328 |         grace_period=1,
329 |         reduction_factor=2)
330 |
331 |     reporter = CLIReporter(
332 |         metric_columns=["loss", "accuracy", "training_iteration"])
333 |
334 |     # Start Ray Tune run
335 |     result = tune.run(
336 |         tune.with_parameters(train_SigWavNet, data=data, max_num_epochs=max_num_epochs),
337 |         resources_per_trial={"cpu": 12, "gpu": gpus_per_trial},
338 |         config=config,
339 |         num_samples=num_samples,
340 |         scheduler=scheduler,
341 |         progress_reporter=reporter,
342 |         local_dir=os.path.abspath(iemocap_experiments_folder+"/IEMOCAP_SigWavNet"),
343 |         log_to_file=(os.path.abspath(iemocap_experiments_folder+"/IEMOCAP_SigWavNet_stdout.log"), os.path.abspath(iemocap_experiments_folder+"/IEMOCAP_SigWavNet_stderr.log")),
344 |         name="IEMOCAP_SigWavNet",
345 |         resume='AUTO')
346 |
347 |     # Extract best trial information
348 |     best_trial = result.get_best_trial("loss", "min", "last")
349 |     print("Best trial config: {}".format(best_trial.config))
350 |     print("Best trial final validation loss: {}".format(
351 |         best_trial.last_result["loss"]))
352 |     print("Best trial final validation accuracy: {}".format(
353 |         best_trial.last_result["accuracy"]))
354 |
355 |     # Initialize best trained model with best trial configuration
356 |     best_trained_model = SigWavNet(n_input=best_trial.config['n_input'], hidden_dim=best_trial.config['hidden_dim'], n_layers=best_trial.config['n_layers'] , n_output=best_trial.config['n_output'], inputSize=None, kernelInit=best_trial.config['kernelInit'], kernTrainable=best_trial.config['kernTrainable'], level=best_trial.config['level'], kernelsConstraint=best_trial.config['mode'], initHT=best_trial.config['initHT'], trainHT=best_trial.config['trainHT'], alpha=best_trial.config['alpha'])
357 |     device = "cpu"
358 |     if torch.cuda.is_available():
359 |         device = "cuda:0" # Use first CUDA device
360 |         if gpus_per_trial > 1:
361 |             best_trained_model = nn.DataParallel(best_trained_model) # Use DataParallel for multi-GPU
362 |     best_trained_model.to(device) # Move model to the chosen device
363 |
364 |     # Load the best model state
365 |     best_checkpoint_dir = best_trial.checkpoint.value
366 |     model_state, optimiser_state = torch.load(os.path.join(best_checkpoint_dir, "checkpoint"))
367 |     best_trained_model.load_state_dict(model_state)
368 |
369 |     # Test the best trained model on the held-out split
370 |     SigWavNet_test_acc_result, y_pred, y_true = test(best_trained_model, best_trial.config["batch_size"], data, test_ds)
371 |
372 |     # Print test set accuracy
373 |     print("Best trial test set accuracy: {}".format(SigWavNet_test_acc_result))
374 |
375 |
376 | if __name__ == "__main__":
377 |     # Main function call, can specify number of samples, epochs, and GPUs per trial
378 |     main(num_samples=10, max_num_epochs=5, gpus_per_trial=1)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | Title: Utility Functions for Speech Emotion Recognition (SigWavNet)
6 | Author: Alaa Nfissi
7 | Date: March 31, 2024
8 | Description: This file contains utility functions for preprocessing, loading data, and
9 | performing various auxiliary tasks for the speech emotion recognition project.
10 | """
11 |
12 | import os
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.optim as optim
18 | import torchaudio
19 |
20 | import pandas as pd
21 | import numpy as np
22 | import matplotlib.pyplot as plt
23 | import IPython.display as ipd
24 | from tqdm.notebook import tqdm
25 | import math
26 | from sklearn.metrics import confusion_matrix, classification_report
27 | import seaborn as sn
28 | torch.manual_seed(123)
29 | import random
30 | import pywt
31 | random.seed(123)
32 |
33 | from ray import tune
34 | from ray.tune import CLIReporter
35 | from ray.tune.schedulers import ASHAScheduler
36 |
37 | from torch import cuda
38 | import gc
39 | import inspect
40 |
41 | from sklearn.model_selection import StratifiedKFold
42 |
43 | d = 0
44 | device = torch.device(f"cuda:{d}")
45 |
46 |
47 |
48 | def load_data(data_path):
49 |     """
50 |     Loads the dataset from a given CSV file path and processes it.
51 |
52 |     Parameters:
53 |     - data_path (str): Path to the CSV file containing the dataset.
54 |
55 |     Returns:
56 |     - Tuple containing the processed dataset DataFrame and a list of unique emotion classes.
57 |     """
58 |     data_path = os.path.abspath(data_path)
59 |
60 |     data = pd.read_csv(data_path)
61 |     data['label'] = data['label'].replace('exc', 'hap')
62 |     data = data[data['label'].isin(['ang', 'hap', 'neu', 'sad'])].reset_index()
63 |     del data['index']
64 |
65 |     emotionclasses = sorted(list(data.label.unique()))
66 |
67 |     return data, emotionclasses
68 |
69 |
70 | def get_dataset_partitions_pd(df, train_split=0.8, val_split=0.1, test_split=0.1, target_variable=None, data_source=None):
71 |
72 |     """
73 |     Splits the dataset into training, validation, and test sets.
74 |
75 |     Parameters:
76 |     - df (DataFrame): The dataset to split.
77 |     - train_split (float): Proportion of the dataset to use for training.
78 |     - val_split (float): Proportion of the dataset to use for validation.
79 |     - test_split (float): Proportion of the dataset to use for testing.
80 |     - target_variable (str, optional): Name of the column containing the target variable for stratification.
81 |     - data_source (str, optional): Name of the column containing the data source for stratification.
82 | 83 | Returns: 84 | - DataFrames for the training, validation, and test sets. 85 | """ 86 | 87 | assert (train_split + test_split + val_split) == 1 88 | 89 | # Only allows for equal validation and test splits 90 | #assert val_split == test_split 91 | # Shuffle 92 | df_sample = df.sample(frac=1, random_state=42) 93 | 94 | # Specify seed to always have the same split distribution between runs 95 | # If target variable is provided, generate stratified sets 96 | arr_list = [] 97 | if target_variable is not None and data_source is not None: 98 | grouped_df = df_sample.groupby([data_source, target_variable]) 99 | #arr_list = [np.split(g, [int(train_split * len(g)), int((1 - val_split) * len(g))]) for i, g in grouped_df] 100 | for i, g in grouped_df: 101 | if len(g) == 3: 102 | arr_list.append(np.split(g, 3)) 103 | else: 104 | arr_list.append(np.split(g, [int(train_split * len(g)), int((1 - val_split) * len(g))])) 105 | train_ds = pd.concat([t[0] for t in arr_list]) 106 | val_ds = pd.concat([v[1] for v in arr_list]) 107 | test_ds = pd.concat([t[2] for t in arr_list]) 108 | 109 | else: 110 | indices_or_sections = [int(train_split * len(df)), int((1 - val_split) * len(df))] 111 | train_ds, val_ds, test_ds = np.split(df_sample, indices_or_sections) 112 | 113 | return train_ds.reset_index(drop=True), val_ds.reset_index(drop=True), test_ds.reset_index(drop=True) 114 | 115 | 116 | 117 | class MyDataset(torch.utils.data.Dataset): 118 | 119 | """ 120 | Custom PyTorch Dataset for loading and processing the speech emotion recognition dataset. 121 | 122 | Attributes: 123 | - paths (list): List of file paths to the audio files. 124 | - labels (list): List of labels corresponding to the audio files. 125 | - transform (callable): A function/transform that takes in an audio file and returns a transformed version. 126 | """ 127 | 128 | def __init__(self, paths, labels, transform): 129 | self.files = paths 130 | self.labels = labels 131 | self.transform = transform 132 | def __getitem__(self, item): 133 | #print(self.files) 134 | file = self.files[item] 135 | label = self.labels[item] 136 | file, sampling_rate = torchaudio.load(file) 137 | file = file if file.shape[0] == 1 else file[0].unsqueeze(0) 138 | file = self.transform(file) 139 | 140 | return file, sampling_rate, label 141 | 142 | def __len__(self): 143 | return len(self.files) 144 | 145 | def compute_precise_mean_std(file_paths): 146 | 147 | """ 148 | Computes the mean and standard deviation of the waveforms in the dataset. 149 | 150 | Parameters: 151 | - file_paths (list): List of paths to the audio files in the dataset. 152 | 153 | Returns: 154 | - Tuple containing the global mean and standard deviation of the waveforms. 155 | """ 156 | 157 | sum_waveform = 0.0 158 | sum_squares = 0.0 159 | total_samples = 0 160 | 161 | for file_path in file_paths: 162 | waveform, _ = torchaudio.load(file_path) 163 | sum_waveform += waveform.sum() 164 | sum_squares += (waveform ** 2).sum() 165 | total_samples += waveform.numel() # Count total number of samples across all files 166 | 167 | # Compute global mean and std 168 | mean = sum_waveform / total_samples 169 | std = (sum_squares / total_samples - mean ** 2) ** 0.5 170 | 171 | return mean.item(), std.item() 172 | 173 | class MyTransformPipeline(nn.Module): 174 | 175 | """ 176 | Custom transform pipeline for processing audio data. 177 | 178 | Parameters: 179 | - train_mean (float): The mean of the training data. 180 | - train_std (float): The standard deviation of the training data. 
181 |     - input_freq (int): The original frequency of the audio data.
182 |     - resample_freq (int): The target frequency to resample the audio data.
183 |     """
184 | 
185 |     def __init__(
186 |         self,
187 |         train_mean = 0,
188 |         train_std = 1,
189 |         input_freq=16000,
190 |         resample_freq=16000,
191 |     ):
192 |         super().__init__()
193 | 
194 |         self.train_mean = train_mean
195 |         self.train_std = train_std
196 |         self.input_freq = input_freq
197 |         self.resample_freq = resample_freq
198 | 
199 |         self.resample = torchaudio.transforms.Resample(orig_freq=self.input_freq, new_freq=self.resample_freq).to(device)
200 | 
201 |     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
202 |         # Resample the input and normalize it with the training-set statistics
203 |         waveform = waveform.to(device)
204 |         resampled = self.resample(waveform)
205 |         normalized_waveform = (resampled - self.train_mean) / self.train_std
206 | 
207 |         return normalized_waveform
208 | 
209 | 
210 | def count_parameters(model):
211 | 
212 |     """
213 |     Counts the number of trainable parameters in a model.
214 | 
215 |     Parameters:
216 |     - model (torch.nn.Module): The model to count parameters for.
217 | 
218 |     Returns:
219 |     - The number of trainable parameters.
220 |     """
221 | 
222 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
223 | 
224 | 
225 | if torch.cuda.is_available():  # use worker processes and pinned memory only when a CUDA device is present
226 |     num_workers = 10
227 |     pin_memory = True
228 | else:
229 |     num_workers = 0
230 |     pin_memory = False
231 | 
232 | 
233 | def get_dataloaders(data, batch_size=32, num_splits=5, stratify=True):
234 | 
235 |     """
236 |     Creates dataloaders for cross-validation, with optional stratification.
237 | 
238 |     Parameters:
239 |     - data (DataFrame): The dataset to be loaded into the dataloaders.
240 |     - batch_size (int): Size of batches.
241 |     - num_splits (int): Number of folds for cross-validation.
242 |     - stratify (bool): Whether to stratify the folds based on labels.
243 | 
244 |     Returns:
245 |     - List of tuples containing train and validation dataloaders for each fold.
246 | """ 247 | 248 | if stratify: 249 | kf = StratifiedKFold(n_splits=num_splits) 250 | split_method = kf.split(data, data['label']) 251 | else: 252 | kf = KFold(n_splits=num_splits) 253 | split_method = kf.split(data) 254 | 255 | dataloaders = [] 256 | 257 | for train_idx, val_idx in split_method: 258 | train_data, val_data = data.iloc[train_idx], data.iloc[val_idx] 259 | 260 | train_data = train_data.reset_index(drop=True) 261 | val_data = val_data.reset_index(drop=True) 262 | 263 | 264 | train_mean, train_std = compute_precise_mean_std(train_data['path']) 265 | 266 | transform = MyTransformPipeline(input_freq=16000, resample_freq=8000, train_mean = train_mean, train_std = train_std) 267 | 268 | transform.to(device) 269 | 270 | train_dataset = MyDataset(train_data['path'], train_data['label'], transform=transform) 271 | val_dataset = MyDataset(val_data['path'], val_data['label'], transform=transform) 272 | 273 | # Create dataloaders for train and validation sets 274 | train_loader = torch.utils.data.DataLoader( 275 | train_dataset, 276 | batch_size=batch_size, 277 | shuffle=True, 278 | drop_last=True, 279 | collate_fn=collate_fn, 280 | num_workers=num_workers, 281 | pin_memory=pin_memory, 282 | ) 283 | val_loader = torch.utils.data.DataLoader( 284 | val_dataset, 285 | batch_size=batch_size, 286 | shuffle=False, 287 | drop_last=True, 288 | collate_fn=collate_fn, 289 | num_workers=num_workers, 290 | pin_memory=pin_memory, 291 | ) 292 | 293 | dataloaders.append((train_loader, val_loader)) 294 | 295 | return dataloaders 296 | 297 | _, emotionclasses = load_data('IEMOCAP_dataset.csv') 298 | 299 | 300 | def index_to_emotionclass(index): 301 | """ 302 | Converts an index to an emotion class. 303 | 304 | Parameters: 305 | - index (int): Index of the emotion class in the list of classes. 306 | 307 | Returns: 308 | - The name of the emotion class corresponding to the given index. 309 | """ 310 | 311 | return emotionclasses[index] 312 | 313 | def emotionclass_to_index(emotion): 314 | 315 | """ 316 | Converts an emotion class to its corresponding index. 317 | 318 | Parameters: 319 | - emotion (str): The emotion class. 320 | 321 | Returns: 322 | - Index of the emotion class in the list of classes. 323 | """ 324 | 325 | return torch.tensor(emotionclasses.index(emotion)) 326 | 327 | def pad_sequence(batch): 328 | 329 | """ 330 | Pads a batch of tensors to the same length with zeros. 331 | 332 | Parameters: 333 | - batch (list of Tensor): The batch of tensors to pad. 334 | 335 | Returns: 336 | - A tensor containing the padded batch. 337 | """ 338 | 339 | batch = [item.t() for item in batch] 340 | 341 | batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.) 342 | return batch.permute(0, 2, 1) 343 | 344 | def collate_fn(batch): 345 | 346 | """ 347 | Custom collate function to process batches of data. 348 | 349 | Parameters: 350 | - batch (list): A batch of data. 351 | 352 | Returns: 353 | - Processed batch of tensors and targets. 
354 | """ 355 | 356 | tensors, targets = [], [] 357 | 358 | # Gather in lists, and encode wordclasses as indices 359 | for waveform, _, emotionclass, *_ in batch: 360 | tensors += [waveform] 361 | targets += [emotionclass_to_index(emotionclass)] 362 | 363 | # Group the list of tensors into a batched tensor 364 | tensors = pad_sequence(tensors) 365 | # stack - Concatenates a sequence of tensors along a new dimension 366 | targets = torch.stack(targets) 367 | 368 | return tensors, targets 369 | 370 | 371 | def get_less_used_gpu(gpus=None, debug=False): 372 | 373 | """ 374 | Finds the least utilized GPU for use. 375 | 376 | Parameters: 377 | - gpus (list): List of GPUs to consider. If None, considers all available GPUs. 378 | - debug (bool): Whether to print debug information. 379 | 380 | Returns: 381 | - ID of the least used GPU. 382 | """ 383 | 384 | if gpus is None: 385 | warn = 'Falling back to default: all gpus' 386 | gpus = range(cuda.device_count()) 387 | elif isinstance(gpus, str): 388 | gpus = [int(el) for el in gpus.split(',')] 389 | 390 | # check gpus arg VS available gpus 391 | sys_gpus = list(range(cuda.device_count())) 392 | if len(gpus) > len(sys_gpus): 393 | gpus = sys_gpus 394 | warn = f'WARNING: Specified {len(gpus)} gpus, but only {cuda.device_count()} available. Falling back to default: all gpus.\nIDs:\t{list(gpus)}' 395 | elif set(gpus).difference(sys_gpus): 396 | # take correctly specified and add as much bad specifications as unused system gpus 397 | available_gpus = set(gpus).intersection(sys_gpus) 398 | unavailable_gpus = set(gpus).difference(sys_gpus) 399 | unused_gpus = set(sys_gpus).difference(gpus) 400 | gpus = list(available_gpus) + list(unused_gpus)[:len(unavailable_gpus)] 401 | warn = f'GPU ids {unavailable_gpus} not available. Falling back to {len(gpus)} device(s).\nIDs:\t{list(gpus)}' 402 | 403 | cur_allocated_mem = {} 404 | cur_cached_mem = {} 405 | max_allocated_mem = {} 406 | max_cached_mem = {} 407 | for i in gpus: 408 | cur_allocated_mem[i] = cuda.memory_allocated(i) 409 | cur_cached_mem[i] = cuda.memory_reserved(i) 410 | max_allocated_mem[i] = cuda.max_memory_allocated(i) 411 | max_cached_mem[i] = cuda.max_memory_reserved(i) 412 | min_allocated = min(cur_allocated_mem, key=cur_allocated_mem.get) 413 | if debug: 414 | print(warn) 415 | print('Current allocated memory:', {f'cuda:{k}': v for k, v in cur_allocated_mem.items()}) 416 | print('Current reserved memory:', {f'cuda:{k}': v for k, v in cur_cached_mem.items()}) 417 | print('Maximum allocated memory:', {f'cuda:{k}': v for k, v in max_allocated_mem.items()}) 418 | print('Maximum reserved memory:', {f'cuda:{k}': v for k, v in max_cached_mem.items()}) 419 | print('Suggested GPU:', min_allocated) 420 | return min_allocated 421 | 422 | 423 | def free_memory(to_delete: list, debug=False): 424 | 425 | """ 426 | Frees up memory by deleting specified variables and collecting garbage. 427 | 428 | Parameters: 429 | - to_delete (list): List of variable names to delete. 430 | - debug (bool): Whether to print debug information before and after freeing memory. 
431 | """ 432 | 433 | calling_namespace = inspect.currentframe().f_back 434 | if debug: 435 | print('Before:') 436 | get_less_used_gpu(debug=True) 437 | 438 | for _var in to_delete: 439 | calling_namespace.f_locals.pop(_var, None) 440 | gc.collect() 441 | cuda.empty_cache() 442 | if debug: 443 | print('After:') 444 | get_less_used_gpu(debug=True) 445 | 446 | 447 | class FocalLoss(nn.Module): 448 | 449 | """ 450 | Implementation of the Focal Loss as a PyTorch module. 451 | 452 | Parameters: 453 | - alpha (Tensor): Weighting factor for the positive class. 454 | - gamma (float): Focusing parameter to adjust the rate at which easy examples contribute to the loss. 455 | """ 456 | 457 | def __init__(self, alpha=None, gamma=2): 458 | super(FocalLoss, self).__init__() 459 | self.alpha = alpha.to(device) 460 | self.gamma = gamma 461 | 462 | def forward(self, inputs, targets): 463 | ce_loss = F.cross_entropy(inputs, targets, reduction='mean') 464 | pt = torch.exp(-ce_loss) 465 | loss = (self.alpha[targets] * (1 - pt) ** self.gamma * ce_loss).mean() 466 | return loss 467 | 468 | 469 | def nr_of_right(pred, target): 470 | 471 | """ 472 | Counts the number of correct predictions. 473 | 474 | Parameters: 475 | - pred (Tensor): Predicted labels. 476 | - target (Tensor): True labels. 477 | 478 | Returns: 479 | - Number of correct predictions. 480 | """ 481 | 482 | return pred.squeeze().eq(target).sum().item() 483 | 484 | def get_probable_idx(tensor): 485 | 486 | """ 487 | Finds the indices of the most probable class for each element in the batch. 488 | 489 | Parameters: 490 | - tensor (Tensor): Tensor containing class probabilities for each element. 491 | 492 | Returns: 493 | - Tensor of indices for the most probable class. 494 | """ 495 | 496 | return tensor.argmax(dim=-1) 497 | 498 | 499 | def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14, normalize=True): 500 | 501 | """ 502 | Prints a confusion matrix using seaborn. 503 | 504 | Parameters: 505 | - confusion_matrix (ndarray): The confusion matrix to print. 506 | - class_names (list): List of class names corresponding to the indices of the confusion matrix. 507 | - figsize (tuple): Size of the figure. 508 | - fontsize (int): Font size for the labels. 509 | - normalize (bool): Whether to normalize the values in the confusion matrix. 510 | """ 511 | 512 | fig = plt.figure(figsize=figsize) 513 | if normalize: 514 | confusion_matrix_1 = (confusion_matrix.astype('float') / confusion_matrix.sum(axis=1)[:, np.newaxis])*100 515 | print("Normalized confusion matrix") 516 | else: 517 | confusion_matrix_1 = confusion_matrix 518 | print('Confusion matrix, without normalization') 519 | df_cm = pd.DataFrame( 520 | confusion_matrix_1, index=class_names, columns=class_names 521 | ) 522 | labels = (np.asarray(["{:1.2f} % \n ({})".format(value, value_1) for value, value_1 in zip(confusion_matrix_1.flatten(),confusion_matrix.flatten())])).reshape(confusion_matrix.shape) 523 | try: 524 | heatmap = sn.heatmap(df_cm, cmap="Blues", annot=labels, fmt='' if normalize else 'd') 525 | except ValueError: 526 | raise ValueError("Confusion matrix values must be integers.") 527 | 528 | heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize) 529 | heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize) 530 | plt.ylabel('True label') 531 | plt.xlabel('Predicted label') --------------------------------------------------------------------------------
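
For orientation, the data utilities in utils.py compose into a short loading pipeline. The sketch below is illustrative only: it assumes utils.py is importable from the repository root, that IEMOCAP_dataset.csv exists with the 'path', 'label' and 'source' columns these functions expect (utils.py loads this CSV at import time), and that a CUDA device is available, since MyTransformPipeline moves tensors to cuda:0.

import torch
from utils import (load_data, get_dataset_partitions_pd, compute_precise_mean_std,
                   MyTransformPipeline, MyDataset, collate_fn)

# Load the filtered four-class IEMOCAP metadata and split it, stratified by source and label.
data, emotionclasses = load_data('IEMOCAP_dataset.csv')
train_ds, val_ds, test_ds = get_dataset_partitions_pd(
    data, train_split=0.8, val_split=0.1, test_split=0.1,
    target_variable='label', data_source='source')

# Normalization statistics are computed on the training split only.
train_mean, train_std = compute_precise_mean_std(train_ds['path'])
transform = MyTransformPipeline(train_mean=train_mean, train_std=train_std)

train_set = MyDataset(train_ds['path'], train_ds['label'], transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate_fn)

waveforms, targets = next(iter(train_loader))  # waveforms: (batch, 1, max_len), targets: (batch,)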
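
The FocalLoss module in utils.py scales a batch-averaged cross-entropy term by per-class weights. For comparison, the focal loss is usually written per sample as FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), with the reduction applied last. The class below is a minimal sketch of that per-sample formulation; the name PerSampleFocalLoss is illustrative and not part of this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PerSampleFocalLoss(nn.Module):
    """Per-sample focal loss with optional per-class weights (alpha)."""

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # optional 1-D tensor with one weight per class
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Unreduced cross-entropy gives one loss value per sample.
        ce = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce)  # probability the model assigns to the true class
        focal = (1.0 - pt) ** self.gamma * ce
        if self.alpha is not None:
            focal = self.alpha.to(inputs.device)[targets] * focal
        return focal.mean()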
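
print_confusion_matrix pairs naturally with the y_pred and y_true lists returned by test() in main.py. A small sketch follows, assuming utils.py imports cleanly; the label lists here are made up for illustration, and the class names match the four emotions kept by load_data.

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from utils import print_confusion_matrix

# Illustrative labels only; in practice pass the y_true / y_pred returned by test() in main.py.
y_true = [0, 1, 2, 3, 0, 1, 2, 3]
y_pred = [0, 1, 2, 2, 0, 1, 3, 3]
class_names = ['ang', 'hap', 'neu', 'sad']  # order matches sorted(data.label.unique()) in load_data

print_confusion_matrix(confusion_matrix(y_true, y_pred), class_names, figsize=(8, 6), fontsize=12)
plt.show()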