├── tests └── requirements.txt ├── requirements.txt ├── LICENSE ├── .github └── workflows │ └── test.yml ├── .gitignore ├── README.md └── notebook.ipynb /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | audiofile >=1.1.0 2 | nbmake 3 | pytest 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audb 2 | audinterface 3 | audmetric 4 | audonnx 5 | audplot 6 | notebook 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 audEERING 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ ubuntu-20.04, windows-latest, macOS-latest ] 16 | python-version: [3.8] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | 21 | - name: Cache emodb 22 | uses: actions/cache@v2 23 | with: 24 | path: cache/emodb 25 | key: emodb-1.3.0 26 | 27 | - name: Cache model 28 | uses: actions/cache@v2 29 | with: 30 | path: cache/model.zip 31 | key: model-1.1.1 32 | 33 | - name: Cache predictions 34 | uses: actions/cache@v2 35 | with: 36 | path: cache/-1033597102444974303.pkl 37 | key: predictions-emodb-1.3.0-model-1.1.1 38 | 39 | - name: Set up Python ${{ matrix.python-version }} 40 | uses: actions/setup-python@v2 41 | with: 42 | python-version: ${{ matrix.python-version }} 43 | 44 | - name: Prepare Ubuntu 45 | run: | 46 | sudo apt-get update 47 | sudo apt-get install -y ffmpeg mediainfo 48 | if: matrix.os == 'ubuntu-latest' || matrix.os == 'ubuntu-20.04' 49 | 50 | - name: Prepare OSX 51 | run: brew install ffmpeg mediainfo 52 | if: matrix.os == 'macOS-latest' 53 | 54 | - name: Windows 55 | run: choco install ffmpeg mediainfo-cli 56 | if: matrix.os == 'windows-latest' 57 | 58 | - name: Install dependencies 59 | run: | 60 | python -V 61 | python -m pip install --upgrade pip 62 | pip install -r requirements.txt 63 | pip install -r tests/requirements.txt 64 | 65 | - name: Test with pytest 66 | run: | 67 | python -m pytest --nbmake --nbmake-timeout=3000 notebook.ipynb 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL 
files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # How to use our public age and gender model 2 | 3 | An introduction to our model for 4 | age and gender prediction based on 5 | [wav2vec 2.0](https://ai.facebook.com/blog/wav2vec-20-learning-the-structure-of-speech-from-raw-audio/). 6 | The model is available from 7 | [doi:10.5281/zenodo.7761387](https://doi.org/10.5281/zenodo.7761387) 8 | and released under 9 | [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). 10 | The model was created 11 | by fine-tuning the pre-trained 12 | [wav2vec2-large-robust](https://huggingface.co/facebook/wav2vec2-large-robust) 13 | model on 14 | [aGender](https://paperswithcode.com/dataset/agender), 15 | [Mozilla Common Voice](https://commonvoice.mozilla.org/), 16 | [Timit](https://catalog.ldc.upenn.edu/LDC93s1) and 17 | [Voxceleb 2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html). 18 | We provide two variants of the model: 19 | one with all 24 transformer layers and 20 | a stripped-down version with six transformer layers. 21 | The models were exported to 22 | [ONNX](https://onnx.ai/). 
23 | The original 24 | [Torch](https://pytorch.org/) 25 | model is hosted at Hugging Face: 26 | [6 layers](https://huggingface.co/audeering/wav2vec2-large-robust-6-ft-age-gender) 27 | and 28 | [24 layers](https://huggingface.co/audeering/wav2vec2-large-robust-24-ft-age-gender). 29 | Further details are given in the associated 30 | [paper](https://arxiv.org/abs/2306.16962) 31 | and [notebook](./notebook.ipynb). 32 | 33 | ## License 34 | 35 | The model can be used for non-commercial purposes, 36 | see [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). 37 | For commercial usage, 38 | a license for 39 | [devAIce](https://www.audeering.com/devaice/) 40 | must be obtained. 41 | The source code in this GitHub repository 42 | is released under the following 43 | [license](./LICENSE). 44 | 45 | ## Quick start 46 | 47 | Create / activate Python virtual environment and install 48 | [audonnx](https://github.com/audeering/audonnx). 49 | 50 | ``` 51 | $ pip install audonnx 52 | ``` 53 | 54 | Load the model with six layers and test on random signal. 
55 | 56 | ```python 57 | import audeer 58 | import audonnx 59 | import numpy as np 60 | 61 | 62 | url = 'https://zenodo.org/record/7761387/files/w2v2-L-robust-6-age-gender.25c844af-1.1.1.zip' 63 | cache_root = audeer.mkdir('cache') 64 | model_root = audeer.mkdir('model') 65 | 66 | archive_path = audeer.download_url(url, cache_root, verbose=True) 67 | audeer.extract_archive(archive_path, model_root) 68 | model = audonnx.load(model_root) 69 | 70 | sampling_rate = 16000 71 | signal = np.random.normal(size=sampling_rate).astype(np.float32) 72 | model(signal, sampling_rate) 73 | ``` 74 | ``` 75 | {'hidden_states': array([[ 0.02783544, 0.01402022, 0.03839185, ..., 0.00786646, 76 | -0.09332313, 0.0915948 ]], dtype=float32), 77 | 'logits_age': array([[0.3961048]], dtype=float32), 78 | 'logits_gender': array([[ 0.32810774, -0.56528044, 0.0317882 ]], dtype=float32)} 79 | ``` 80 | 81 | The 'hidden_states' are the pooled states of the last transformer layer, 82 | 'logits_age' provides scores for age in a range of approximately 0...1 (== 100 years) 83 | and 'logits_gender' expresses the confidence for being female, male or child. 84 | 85 | ## Tutorial 86 | 87 | For a detailed introduction, please check out the [notebook](./notebook.ipynb). 
88 | 89 | ```bash 90 | $ pip install -r requirements.txt 91 | $ jupyter notebook notebook.ipynb 92 | ``` 93 | 94 | ## Citation 95 | 96 | If you use our model in your own work, please cite the following 97 | [paper](https://arxiv.org/abs/2306.16962): 98 | 99 | ``` bibtex 100 | @inproceedings{burkhardt2023speech, 101 | author = {Felix Burkhardt and Johannes Wagner and Hagen Wierstorf and Florian Eyben and Björn Schuller}, 102 | editor = {Peter Jax and Sebastian Möller}, 103 | booktitle = {15th ITG conference on Speech Communication}, 104 | title = {Speech-based Age and Gender Prediction with Transformers}, 105 | year = {2023}, 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "fc82a8e4", 6 | "metadata": {}, 7 | "source": [ 8 | "# How to use our wav2vec 2.0 model to predict age and gender" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "23f9dd5c", 14 | "metadata": {}, 15 | "source": [ 16 | "In the following we present a hands-on introduction to our model for age and gender prediction based on wav2vec 2.0. The model is publicly available for non-commercial usage from https://doi.org/10.5281/zenodo.7761387." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "313f2e2f", 22 | "metadata": { 23 | "hide_input": false 24 | }, 25 | "source": [ 26 | "## Load model" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "af91aa2a", 32 | "metadata": {}, 33 | "source": [ 34 | "We start by downloading and unpacking the model (we will use the six layer version). This will get us two files, a binary ONNX file containing the model weights and a YAML file with meta information about the model." 
35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 1, 40 | "id": "5fa18d66", 41 | "metadata": { 42 | "scrolled": true 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import os\n", 47 | "\n", 48 | "import audeer\n", 49 | "\n", 50 | "\n", 51 | "model_root = 'model'\n", 52 | "cache_root = 'cache'\n", 53 | "\n", 54 | "\n", 55 | "audeer.mkdir(cache_root)\n", 56 | "def cache_path(file):\n", 57 | " return os.path.join(cache_root, file)\n", 58 | "\n", 59 | "\n", 60 | "url = 'https://zenodo.org/record/7761387/files/w2v2-L-robust-6-age-gender.25c844af-1.1.1.zip'\n", 61 | "dst_path = cache_path('model.zip')\n", 62 | "\n", 63 | "if not os.path.exists(dst_path):\n", 64 | " audeer.download_url(url, dst_path, verbose=True)\n", 65 | " \n", 66 | "if not os.path.exists(model_root):\n", 67 | " audeer.extract_archive(dst_path, model_root, verbose=True)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "4b5c7928", 73 | "metadata": {}, 74 | "source": [ 75 | "The package [audonnx](https://github.com/audeering/audonnx) helps us to load the model. Printing the model lists the input and output nodes. Since the model operates on the raw audio stream, we have a single input node called 'signal', which expects a mono signal with a sampling rate of 16000 Hz. We also see that the model has three output nodes: 'hidden_states', which gives us access to the pooled states of the last transformer layer, 'logits_age', which provides scores for age in a range of approximately 0...1 (== 100 years) and 'logits_gender', which expresses the confidence for being female, male or child." 
76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "id": "f0b1a074", 82 | "metadata": { 83 | "scrolled": true 84 | }, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/plain": [ 89 | "Input:\n", 90 | " signal:\n", 91 | " shape: [1, -1]\n", 92 | " dtype: tensor(float)\n", 93 | " transform: None\n", 94 | "Output:\n", 95 | " hidden_states:\n", 96 | " shape: [1, 1024]\n", 97 | " dtype: tensor(float)\n", 98 | " labels: [hidden_states-0, hidden_states-1, hidden_states-2, (...), hidden_states-1021,\n", 99 | " hidden_states-1022, hidden_states-1023]\n", 100 | " logits_age:\n", 101 | " shape: [1, 1]\n", 102 | " dtype: tensor(float)\n", 103 | " labels: [age]\n", 104 | " logits_gender:\n", 105 | " shape: [1, 3]\n", 106 | " dtype: tensor(float)\n", 107 | " labels: [female, male, child]" 108 | ] 109 | }, 110 | "execution_count": 2, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "import audonnx\n", 117 | "\n", 118 | "\n", 119 | "model = audonnx.load(model_root)\n", 120 | "model" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "573cf22f", 126 | "metadata": {}, 127 | "source": [ 128 | "As a test, we call the model with some white noise. Note that we have to force the data type of the signal to 32-bit floating point precision. As result we get a dictionary with predictions for every output node." 
129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 3, 134 | "id": "9ef10804", 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/plain": [ 140 | "{'hidden_states': array([[ 0.02783544, 0.01402022, 0.03839185, ..., 0.00786646,\n", 141 | " -0.09332313, 0.0915948 ]], dtype=float32),\n", 142 | " 'logits_age': array([[0.3961048]], dtype=float32),\n", 143 | " 'logits_gender': array([[ 0.32810774, -0.56528044, 0.0317882 ]], dtype=float32)}" 144 | ] 145 | }, 146 | "execution_count": 3, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "import numpy as np\n", 153 | "\n", 154 | "\n", 155 | "np.random.seed(0)\n", 156 | "\n", 157 | "sampling_rate = 16000\n", 158 | "signal = np.random.normal(\n", 159 | " size=sampling_rate,\n", 160 | ").astype(np.float32)\n", 161 | "\n", 162 | "model(signal, sampling_rate)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "id": "c016311c", 168 | "metadata": {}, 169 | "source": [ 170 | "## Predict age and gender" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "id": "f17622eb", 176 | "metadata": {}, 177 | "source": [ 178 | "A more advanced way of interfacing the model is offered by [audinterface](https://github.com/audeering/audinterface). Especially, the class [Feature](https://audeering.github.io/audinterface/api.html#feature) comes in handy, as it has the option to assign names to the output dimensions. To create the interface, we simply pass our callable model object as processing function. Since we are only interested in the scores for age and gender, we pass 'logits_age' and 'logits_gender' as an additional key word argument to `outputs`. To concatenate the outputs we set `concat` to `True`. And we enable automatic resampling in case the expected sampling rate of the model is not matched." 
179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 4, 184 | "id": "877be883", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "import audinterface\n", 189 | "\n", 190 | "\n", 191 | "outputs = ['logits_age', 'logits_gender']\n", 192 | "interface = audinterface.Feature(\n", 193 | " model.labels(outputs),\n", 194 | " process_func=model,\n", 195 | " process_func_args={\n", 196 | " 'outputs': outputs,\n", 197 | " 'concat': True,\n", 198 | " },\n", 199 | " sampling_rate=sampling_rate,\n", 200 | " resample=True, \n", 201 | " verbose=True,\n", 202 | ")" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "9ead873c", 208 | "metadata": {}, 209 | "source": [ 210 | "When we pass the signal to the interface, we get as result a table with proper column labels." 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 5, 216 | "id": "6f001a87", 217 | "metadata": { 218 | "scrolled": true 219 | }, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "text/html": [ 224 | "
\n", 225 | "\n", 238 | "\n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | "
agefemalemalechild
startend
0 days0 days 00:00:010.3961050.328108-0.565280.031788
\n", 268 | "
" 269 | ], 270 | "text/plain": [ 271 | " age female male child\n", 272 | "start end \n", 273 | "0 days 0 days 00:00:01 0.396105 0.328108 -0.56528 0.031788" 274 | ] 275 | }, 276 | "execution_count": 5, 277 | "metadata": {}, 278 | "output_type": "execute_result" 279 | } 280 | ], 281 | "source": [ 282 | "interface.process_signal(signal, sampling_rate)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "id": "c411a9ff", 288 | "metadata": {}, 289 | "source": [ 290 | "## Evaluate model on emodb" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "id": "68188d94", 296 | "metadata": {}, 297 | "source": [ 298 | "The Berlin Database of Emotional Speech ([Emo-DB](http://emodb.bilderbar.info)) is a well-known speech database with emotional utterances by different actors. It contains information on speaker age and gender, and expressions of emotional and non-emotional sentences in German. It will allow us to investigate if a) the model works on languages other than English and b) it still works for affective speech. To get the database we use [audb](https://github.com/audeering/audb), a package to manage annotated media files. When we load the data, audb takes care of caching and converting the files to the desired format. Annotations are organized as tables in [audformat](https://github.com/audeering/audformat). In the following experiment we use columns with speaker information and emotional labels." 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 6, 304 | "id": "4ef6e326", 305 | "metadata": { 306 | "scrolled": true 307 | }, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "Get: emodb v1.3.0\n", 314 | "Cache: /media/jwagner/Data/Git/how-to/w2v2-age-gender-how-to/cache/emodb/1.3.0/fe182b91\n" 315 | ] 316 | }, 317 | { 318 | "name": "stderr", 319 | "output_type": "stream", 320 | "text": [ 321 | " \r" 322 | ] 323 | }, 324 | { 325 | "data": { 326 | "text/html": [ 327 | "
\n", 328 | "\n", 341 | "\n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | "
agegenderemotion
file
wav/03a01Fa.wav31malehappiness
wav/03a01Nc.wav31maleneutral
wav/03a01Wa.wav31maleanger
wav/03a02Fc.wav31malehappiness
wav/03a02Nc.wav31maleneutral
............
wav/16b10Lb.wav31femaleboredom
wav/16b10Tb.wav31femalesadness
wav/16b10Td.wav31femalesadness
wav/16b10Wa.wav31femaleanger
wav/16b10Wb.wav31femaleanger
\n", 425 | "

535 rows × 3 columns

\n", 426 | "
" 427 | ], 428 | "text/plain": [ 429 | " age gender emotion\n", 430 | "file \n", 431 | "wav/03a01Fa.wav 31 male happiness\n", 432 | "wav/03a01Nc.wav 31 male neutral\n", 433 | "wav/03a01Wa.wav 31 male anger\n", 434 | "wav/03a02Fc.wav 31 male happiness\n", 435 | "wav/03a02Nc.wav 31 male neutral\n", 436 | "... ... ... ...\n", 437 | "wav/16b10Lb.wav 31 female boredom\n", 438 | "wav/16b10Tb.wav 31 female sadness\n", 439 | "wav/16b10Td.wav 31 female sadness\n", 440 | "wav/16b10Wa.wav 31 female anger\n", 441 | "wav/16b10Wb.wav 31 female anger\n", 442 | "\n", 443 | "[535 rows x 3 columns]" 444 | ] 445 | }, 446 | "execution_count": 6, 447 | "metadata": {}, 448 | "output_type": "execute_result" 449 | } 450 | ], 451 | "source": [ 452 | "import audb\n", 453 | "import audformat\n", 454 | "\n", 455 | "\n", 456 | "db = audb.load(\n", 457 | " 'emodb',\n", 458 | " version='1.3.0',\n", 459 | " format='wav',\n", 460 | " mixdown=True,\n", 461 | " sampling_rate=16000,\n", 462 | " full_path=False, \n", 463 | " cache_root=cache_root,\n", 464 | " verbose=True,\n", 465 | ")\n", 466 | "age = db['files']['speaker'].get(map='age')\n", 467 | "gender = db['files']['speaker'].get(map='gender')\n", 468 | "emotion = db['emotion']['emotion'].get()\n", 469 | "\n", 470 | "df = audformat.utils.concat([age, gender, emotion])\n", 471 | "df" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "id": "04a2c177", 477 | "metadata": {}, 478 | "source": [ 479 | "The interface we created earlier offers us a convenient way to run the model directly on the index of a table. Note that we cache the feature once extracted to avoid re-calculation." 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": 7, 485 | "id": "f7bca356", 486 | "metadata": {}, 487 | "outputs": [ 488 | { 489 | "data": { 490 | "text/html": [ 491 | "
\n", 492 | "\n", 505 | "\n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | "
agefemalemalechild
filestartend
wav/03a01Fa.wav0 days0 days 00:00:01.8982500.4405411.4648814.078879-4.503115
wav/03a01Nc.wav0 days0 days 00:00:01.6112500.262211-1.3949646.082706-3.376553
wav/03a01Wa.wav0 days0 days 00:00:01.8778125000.4178280.1943855.651813-4.512995
wav/03a02Fc.wav0 days0 days 00:00:02.0062500.3111201.9071801.848307-3.264290
wav/03a02Nc.wav0 days0 days 00:00:01.4398125000.287216-1.3020276.188240-3.562538
.....................
wav/16b10Lb.wav0 days0 days 00:00:03.4426875000.3998406.157269-1.738880-4.453460
wav/16b10Tb.wav0 days0 days 00:00:03.5006250.3918745.663276-0.568825-4.939527
wav/16b10Td.wav0 days0 days 00:00:03.9341875000.3541475.773487-1.090130-4.626004
wav/16b10Wa.wav0 days0 days 00:00:02.4141250.4144064.1733812.015708-5.470536
wav/16b10Wb.wav0 days0 days 00:00:02.5224999990.3885603.8740071.189142-4.586453
\n", 628 | "

535 rows × 4 columns

\n", 629 | "
" 630 | ], 631 | "text/plain": [ 632 | " age female \\\n", 633 | "file start end \n", 634 | "wav/03a01Fa.wav 0 days 0 days 00:00:01.898250 0.440541 1.464881 \n", 635 | "wav/03a01Nc.wav 0 days 0 days 00:00:01.611250 0.262211 -1.394964 \n", 636 | "wav/03a01Wa.wav 0 days 0 days 00:00:01.877812500 0.417828 0.194385 \n", 637 | "wav/03a02Fc.wav 0 days 0 days 00:00:02.006250 0.311120 1.907180 \n", 638 | "wav/03a02Nc.wav 0 days 0 days 00:00:01.439812500 0.287216 -1.302027 \n", 639 | "... ... ... \n", 640 | "wav/16b10Lb.wav 0 days 0 days 00:00:03.442687500 0.399840 6.157269 \n", 641 | "wav/16b10Tb.wav 0 days 0 days 00:00:03.500625 0.391874 5.663276 \n", 642 | "wav/16b10Td.wav 0 days 0 days 00:00:03.934187500 0.354147 5.773487 \n", 643 | "wav/16b10Wa.wav 0 days 0 days 00:00:02.414125 0.414406 4.173381 \n", 644 | "wav/16b10Wb.wav 0 days 0 days 00:00:02.522499999 0.388560 3.874007 \n", 645 | "\n", 646 | " male child \n", 647 | "file start end \n", 648 | "wav/03a01Fa.wav 0 days 0 days 00:00:01.898250 4.078879 -4.503115 \n", 649 | "wav/03a01Nc.wav 0 days 0 days 00:00:01.611250 6.082706 -3.376553 \n", 650 | "wav/03a01Wa.wav 0 days 0 days 00:00:01.877812500 5.651813 -4.512995 \n", 651 | "wav/03a02Fc.wav 0 days 0 days 00:00:02.006250 1.848307 -3.264290 \n", 652 | "wav/03a02Nc.wav 0 days 0 days 00:00:01.439812500 6.188240 -3.562538 \n", 653 | "... ... ... 
\n", 654 | "wav/16b10Lb.wav 0 days 0 days 00:00:03.442687500 -1.738880 -4.453460 \n", 655 | "wav/16b10Tb.wav 0 days 0 days 00:00:03.500625 -0.568825 -4.939527 \n", 656 | "wav/16b10Td.wav 0 days 0 days 00:00:03.934187500 -1.090130 -4.626004 \n", 657 | "wav/16b10Wa.wav 0 days 0 days 00:00:02.414125 2.015708 -5.470536 \n", 658 | "wav/16b10Wb.wav 0 days 0 days 00:00:02.522499999 1.189142 -4.586453 \n", 659 | "\n", 660 | "[535 rows x 4 columns]" 661 | ] 662 | }, 663 | "execution_count": 7, 664 | "metadata": {}, 665 | "output_type": "execute_result" 666 | } 667 | ], 668 | "source": [ 669 | "pred = interface.process_index(\n", 670 | " df.index,\n", 671 | " root=db.root,\n", 672 | " cache_root=cache_root,\n", 673 | ")\n", 674 | "pred" 675 | ] 676 | }, 677 | { 678 | "cell_type": "markdown", 679 | "id": "c2e66316", 680 | "metadata": {}, 681 | "source": [ 682 | "We multiply the normalized age predictions by 100." 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": 8, 688 | "id": "8333fb54", 689 | "metadata": {}, 690 | "outputs": [], 691 | "source": [ 692 | "pred_age = pred.age * 100" 693 | ] 694 | }, 695 | { 696 | "cell_type": "markdown", 697 | "id": "35609542", 698 | "metadata": {}, 699 | "source": [ 700 | "And determine the winning gender class." 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 9, 706 | "id": "6e750778", 707 | "metadata": {}, 708 | "outputs": [], 709 | "source": [ 710 | "pred_gender = pred.drop('age', axis=1).idxmax(axis=1)" 711 | ] 712 | }, 713 | { 714 | "cell_type": "markdown", 715 | "id": "e0781912", 716 | "metadata": {}, 717 | "source": [ 718 | "We measure age performance by means of Mean Absolute Error (MAE), which we calculate with [audmetric](https://github.com/audeering/audmetric). In addition, we show a distribution plot that we create with [audplot](https://github.com/audeering/audplot)." 
719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": 10, 724 | "id": "45640e3b", 725 | "metadata": { 726 | "scrolled": false 727 | }, 728 | "outputs": [ 729 | { 730 | "data": { 731 | "text/plain": [ 732 | "8.348510439150802" 733 | ] 734 | }, 735 | "execution_count": 10, 736 | "metadata": {}, 737 | "output_type": "execute_result" 738 | }, 739 | { 740 | "data": { 741 | "image/png": "\n", 742 | "text/plain": [ 743 | "
" 744 | ] 745 | }, 746 | "metadata": {}, 747 | "output_type": "display_data" 748 | } 749 | ], 750 | "source": [ 751 | "import audmetric\n", 752 | "import audplot\n", 753 | "\n", 754 | "\n", 755 | "audplot.distribution(age, pred_age)\n", 756 | "audmetric.mean_absolute_error( \n", 757 | " age,\n", 758 | " pred_age,\n", 759 | ")" 760 | ] 761 | }, 762 | { 763 | "cell_type": "markdown", 764 | "id": "d747718a", 765 | "metadata": {}, 766 | "source": [ 767 | "A mean absolute error of about 8 years is within the range we see for the datasets in the paper (tba). From the distribution we see that some of the samples are predicted too old.\n", 768 | "\n", 769 | "To evaluate our gender predictions we calculate Unweighted Average Recall (UAR) and show a confusion matrix." 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 11, 775 | "id": "7978496c", 776 | "metadata": { 777 | "scrolled": false 778 | }, 779 | "outputs": [ 780 | { 781 | "data": { 782 | "text/plain": [ 783 | "96.04496489781997" 784 | ] 785 | }, 786 | "execution_count": 11, 787 | "metadata": {}, 788 | "output_type": "execute_result" 789 | }, 790 | { 791 | "data": { 792 | "image/png": "\n", 793 | "text/plain": [ 794 | "
" 795 | ] 796 | }, 797 | "metadata": {}, 798 | "output_type": "display_data" 799 | } 800 | ], 801 | "source": [ 802 | "audplot.confusion_matrix(\n", 803 | " gender, \n", 804 | " pred_gender,\n", 805 | " percentage=True,\n", 806 | " show_both=True,\n", 807 | ")\n", 808 | "audmetric.unweighted_average_recall(\n", 809 | " gender, \n", 810 | " pred_gender,\n", 811 | ") * 100" 812 | ] 813 | }, 814 | { 815 | "cell_type": "markdown", 816 | "id": "830ec788", 817 | "metadata": {}, 818 | "source": [ 819 | "The confusion matrix tells us that almost all male samples are correctly labeled, but some of the female samples are mis-classified. A likely assumption is that the model has mainly issues with the affective speech. To prove this, we repeat the evaluation on the neutral samples." 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": 12, 825 | "id": "25a7b833", 826 | "metadata": { 827 | "scrolled": false 828 | }, 829 | "outputs": [ 830 | { 831 | "data": { 832 | "text/plain": [ 833 | "5.940903651563427" 834 | ] 835 | }, 836 | "execution_count": 12, 837 | "metadata": {}, 838 | "output_type": "execute_result" 839 | }, 840 | { 841 | "data": { 842 | "image/png": "\n", 843 | "text/plain": [ 844 | "
" 845 | ] 846 | }, 847 | "metadata": {}, 848 | "output_type": "display_data" 849 | } 850 | ], 851 | "source": [ 852 | "mask = (emotion == 'neutral').values\n", 853 | "\n", 854 | "audplot.distribution(age[mask], pred_age[mask])\n", 855 | "audmetric.mean_absolute_error( \n", 856 | " age[mask],\n", 857 | " pred_age[mask],\n", 858 | ")" 859 | ] 860 | }, 861 | { 862 | "cell_type": "markdown", 863 | "id": "4f595ad0", 864 | "metadata": {}, 865 | "source": [ 866 | "And indeed, the error of the age prediction decreases by 2 years. And we get a perfect prediction of gender." 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": 13, 872 | "id": "1f3375bb", 873 | "metadata": { 874 | "scrolled": false 875 | }, 876 | "outputs": [ 877 | { 878 | "data": { 879 | "text/plain": [ 880 | "100.0" 881 | ] 882 | }, 883 | "execution_count": 13, 884 | "metadata": {}, 885 | "output_type": "execute_result" 886 | }, 887 | { 888 | "data": { 889 | "image/png": "\n", 890 | "text/plain": [ 891 | "
" 892 | ] 893 | }, 894 | "metadata": {}, 895 | "output_type": "display_data" 896 | } 897 | ], 898 | "source": [ 899 | "audplot.confusion_matrix(\n", 900 | " gender[mask], \n", 901 | " pred_gender[mask],\n", 902 | " percentage=True,\n", 903 | " show_both=True,\n", 904 | ")\n", 905 | "audmetric.unweighted_average_recall(\n", 906 | " gender[mask], \n", 907 | " pred_gender[mask],\n", 908 | ") * 100" 909 | ] 910 | } 911 | ], 912 | "metadata": { 913 | "kernelspec": { 914 | "display_name": "Python 3 (ipykernel)", 915 | "language": "python", 916 | "name": "python3" 917 | }, 918 | "language_info": { 919 | "codemirror_mode": { 920 | "name": "ipython", 921 | "version": 3 922 | }, 923 | "file_extension": ".py", 924 | "mimetype": "text/x-python", 925 | "name": "python", 926 | "nbconvert_exporter": "python", 927 | "pygments_lexer": "ipython3", 928 | "version": "3.8.16" 929 | }, 930 | "toc": { 931 | "base_numbering": 1, 932 | "nav_menu": {}, 933 | "number_sections": true, 934 | "sideBar": true, 935 | "skip_h1_title": false, 936 | "title_cell": "Table of Contents", 937 | "title_sidebar": "Contents", 938 | "toc_cell": false, 939 | "toc_position": {}, 940 | "toc_section_display": true, 941 | "toc_window_display": false 942 | } 943 | }, 944 | "nbformat": 4, 945 | "nbformat_minor": 5 946 | } 947 | --------------------------------------------------------------------------------