├── .gitattributes ├── LICENSE ├── .gitignore ├── BiHDM_example.ipynb ├── README.md └── BiHDM.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yichen Tang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
152 | #.idea/ 153 | -------------------------------------------------------------------------------- /BiHDM_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from sklearn.base import clone\n", 11 | "from sklearn.model_selection import cross_val_score, StratifiedKFold\n", 12 | "\n", 13 | "from BiHDM import BiHDMClassifier" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# Define 64 EEG channels using 10-20 standard (on a 64-channel BioSemi cap)\n", 23 | "ch_names = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', \n", 24 | " 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', \n", 25 | " 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', \n", 26 | " 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', \n", 27 | " 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', \n", 28 | " 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2']\n", 29 | "\n", 30 | "lh_chs = ['Fp1', 'AF7', 'AF3', 'F7', 'F5', 'F3', 'F1', 'FT7', 'FC5', 'FC3', 'FC1', \n", 31 | " 'T7', 'C5', 'C3', 'C1', 'TP7', 'CP5', 'CP3', 'CP1', 'P7', 'P5', 'P3', 'P1', \n", 32 | " 'PO7', 'PO3', 'O1']\n", 33 | "rh_chs = ['Fp2', 'AF8', 'AF4', 'F8', 'F6', 'F4', 'F2', 'FT8', 'FC6', 'FC4', 'FC2', \n", 34 | " 'T8', 'C6', 'C4', 'C2','TP8', 'CP6', 'CP4', 'CP2', 'P8', 'P6', 'P4', 'P2', \n", 35 | " 'PO8', 'PO4', 'O2']\n", 36 | "lv_chs = ['Fp1', 'AF7', 'F7', 'FT7', 'T7', 'TP7', 'P7', 'PO7', 'AF3', 'F5', 'FC5', \n", 37 | " 'C5', 'CP5', 'P5', 'O1', 'F3', 'FC3', 'C3', 'CP3', 'P3', 'PO3', 'F1', 'FC1', \n", 38 | " 'C1', 'CP1', 'P1']\n", 39 | "rv_chs = ['Fp2', 'AF8', 'F8', 'FT8', 'T8', 'TP8', 'P8', 'PO8', 'AF4', 'F6', 'FC6', \n", 40 
| " 'C6', 'CP6', 'P6', 'O2', 'F4', 'FC4', 'C4', 'CP4', 'P4', 'PO4', 'F2', 'FC2', \n", 41 | " 'C2', 'CP2', 'P2']" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# Generate some data for classification\n", 51 | "X = np.ones((1000, 64, 5)) # 1000 samples x 64 channels x 5 bands per channel (delta, theta, alpha, beta, gamma)\n", 52 | "y = np.repeat([0,1], 500)\n", 53 | "\n", 54 | "# Let's simulate a frontal alpha-asymmetry for the classifier to learn from\n", 55 | "left_frontal_chs = ['Fp1', 'AF3', 'AF7', 'F1', 'F3', 'F5', 'FC3', 'FC1']\n", 56 | "X[:500,np.isin(ch_names, left_frontal_chs),2] -= 1\n", 57 | "\n", 58 | "# And let's add some gaussian noise\n", 59 | "rng = np.random.default_rng(42)\n", 60 | "X += rng.normal(scale=0.5, size=X.shape)\n", 61 | "\n", 62 | "# Reshape X to meet sklearn standard\n", 63 | "X = X.reshape(1000, -1)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "0.783\n", 76 | "[0.745 0.76 0.775 0.805 0.83 ]\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "clf = BiHDMClassifier(ch_names, lh_chs, rh_chs, lv_chs, rv_chs, \n", 82 | " d_stream=32, d_pair=32, d_global=32, d_out=16, \n", 83 | " k=6, a=0.01, pairwise_operation='subtraction', \n", 84 | " rnn_stream_kwargs={}, rnn_global_kwargs={}, \n", 85 | " loss='NLLLoss', domain_loss='NLLLoss', optimizer='SGD', lr=0.003, \n", 86 | " epochs=8, batch_size=200, loss_kwargs={}, domain_loss_kwargs={}, \n", 87 | " optimizer_kwargs=dict(momentum=0.9, weight_decay=0.95),\n", 88 | " random_state=42, use_gpu=True, verbose=False)\n", 89 | "\n", 90 | "# first let's test the model without performing the domain adversarial strategy\n", 91 | "scores = cross_val_score(clf, X, y)\n", 92 | "print(np.mean(scores))\n", 93 | "print(scores)" 94 | ] 95 | }, 96 | { 97 | "attachments": {}, 98 | 
"cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "Then let's test the model again, but this time with the domain adversarial strategy." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 5, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "def custom_cross_val_score(clf, X, y):\n", 111 | " cv = StratifiedKFold()\n", 112 | " scores = []\n", 113 | " for train_index, test_index in cv.split(X, y):\n", 114 | " cloned_clf = clone(clf)\n", 115 | " X_train = X[train_index]\n", 116 | " y_train = y[train_index]\n", 117 | " X_test = X[test_index]\n", 118 | " y_test = y[test_index]\n", 119 | " \n", 120 | " # provide X_test for performing the domain adversarial strategy\n", 121 | " cloned_clf.fit(X_train, y_train, X_test=X_test)\n", 122 | " score = cloned_clf.score(X_test, y_test)\n", 123 | " scores.append(score)\n", 124 | " return np.array(scores)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "0.7870000000000001\n", 137 | "[0.795 0.78 0.835 0.85 0.675]\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "scores_d = custom_cross_val_score(clf, X, y)\n", 143 | "print(np.mean(scores_d))\n", 144 | "print(scores_d)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "prosody_1", 158 | "language": "python", 159 | "name": "python3" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.9.15" 172 | }, 173 | "orig_nbformat": 4, 174 | "vscode": { 175 | 
"interpreter": { 176 | "hash": "3f38051f1ddded3e326d0b827e329a10034216e79b2f401ba9b72226d19a5307" 177 | } 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 2 182 | } 183 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BiHDM_pytorch 2 | An unofficial PyTorch implementation of the BiHDM model proposed by Li et al. [1] for decoding emotion from multi-channel electroencephalogram (EEG) recordings, with scikit-learn compatibility. 3 | 4 | > **Warning** 5 | > Please note that this is not an official implementation, nor has it been tested on the datasets used in the original study. Due to the different libraries and hyperparameters used in this implementation (and potential implementation errors), the performance of this model may differ from that described in the original paper. Please always examine the source code, make your own changes if necessary, and describe the actual implementation if you are using this model for an academic study. And please raise an issue if you find any implementation errors in my code, thank you! 6 | 7 | ## Introduction 8 | 9 | This repository presents a PyTorch implementation of the BiHDM model proposed by Li et al. [1]. The BiHDM model effectively leverages the bi-hemispheric discrepancy features of EEG to achieve high classification accuracies in decoding emotions. 10 | 11 | The BiHDM model first obtains deep representations for electrodes on the left and right hemispheres separately, utilizing either a horizontal or a vertical stream. Within each stream, the model learns bi-hemisphere discrepancy features by performing pairwise operations on the deep representations for matching electrodes on the two hemispheres. Finally, the bi-hemisphere discrepancy features from both horizontal and vertical streams are combined to predict the emotion label for the given EEG sample (see Fig. 
1 in [1]). 12 | 13 | The default hyper-parameters utilized in this implementation are based on the settings outlined in the original paper [1]. While these settings performed generally well with my own EEG datasets, tuning certain hyperparameters did lead to improved classification accuracies. When applying this implementation for your own projects, you may want to experiment with these settings for best outcomes. 14 | 15 | ## Requirements 16 | This model was coded and tested on Python 3.9 with the following libraries and versions (minor differences in versions should not affect the model outcomes): 17 | 18 | ```Python 19 | numpy >= 1.21.6 20 | scikit-learn >= 1.1.3 21 | torch == 1.13.1+cu116 22 | ``` 23 | 24 | ## Examples 25 | 26 | See "BiHDM_example.ipynb". 27 | 28 | ```Python 29 | >>> import numpy as np 30 | >>> from sklearn.base import clone 31 | >>> from sklearn.model_selection import cross_val_score, StratifiedKFold 32 | 33 | >>> from BiHDM import BiHDMClassifier 34 | 35 | >>> # Define 64 EEG channels using 10-20 standard (on a 64-channel BioSemi cap) 36 | >>> ch_names = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 37 | >>> 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 38 | >>> 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', 39 | >>> 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 40 | >>> 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 41 | >>> 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'] 42 | 43 | >>> lh_chs = ['Fp1', 'AF7', 'AF3', 'F7', 'F5', 'F3', 'F1', 'FT7', 'FC5', 'FC3', 'FC1', 44 | >>> 'T7', 'C5', 'C3', 'C1', 'TP7', 'CP5', 'CP3', 'CP1', 'P7', 'P5', 'P3', 'P1', 45 | >>> 'PO7', 'PO3', 'O1'] 46 | >>> rh_chs = ['Fp2', 'AF8', 'AF4', 'F8', 'F6', 'F4', 'F2', 'FT8', 'FC6', 'FC4', 'FC2', 47 | >>> 'T8', 'C6', 'C4', 'C2','TP8', 'CP6', 'CP4', 'CP2', 'P8', 'P6', 'P4', 'P2', 48 | >>> 'PO8', 'PO4', 'O2'] 49 | >>> lv_chs = ['Fp1', 
'AF7', 'F7', 'FT7', 'T7', 'TP7', 'P7', 'PO7', 'AF3', 'F5', 'FC5', 50 | >>> 'C5', 'CP5', 'P5', 'O1', 'F3', 'FC3', 'C3', 'CP3', 'P3', 'PO3', 'F1', 'FC1', 51 | >>> 'C1', 'CP1', 'P1'] 52 | >>> rv_chs = ['Fp2', 'AF8', 'F8', 'FT8', 'T8', 'TP8', 'P8', 'PO8', 'AF4', 'F6', 'FC6', 53 | >>> 'C6', 'CP6', 'P6', 'O2', 'F4', 'FC4', 'C4', 'CP4', 'P4', 'PO4', 'F2', 'FC2', 54 | >>> 'C2', 'CP2', 'P2'] 55 | 56 | >>> # Generate some data for classification 57 | >>> X = np.ones((1000, 64, 5)) # 1000 samples x 64 channels x 5 bands per channel (delta, theta, alpha, beta, gamma) 58 | >>> y = np.repeat([0,1], 500) 59 | 60 | >>> # Let's simulate a frontal alpha-asymmetry for the classifier to learn from 61 | >>> left_frontal_chs = ['Fp1', 'AF3', 'AF7', 'F1', 'F3', 'F5', 'FC3', 'FC1'] 62 | >>> X[:500,np.isin(ch_names, left_frontal_chs),2] -= 1 63 | 64 | >>> # And let's add some gaussian noise 65 | >>> rng = np.random.default_rng(42) 66 | >>> X += rng.normal(scale=0.5, size=X.shape) 67 | 68 | >>> # Reshape X to meet sklearn standard 69 | >>> X = X.reshape(1000, -1) 70 | 71 | 72 | >>> clf = BiHDMClassifier(ch_names, lh_chs, rh_chs, lv_chs, rv_chs, 73 | >>> d_stream=32, d_pair=32, d_global=32, d_out=16, 74 | >>> k=6, a=0.01, pairwise_operation='subtraction', 75 | >>> rnn_stream_kwargs={}, rnn_global_kwargs={}, 76 | >>> loss='NLLLoss', domain_loss='NLLLoss', optimizer='SGD', lr=0.003, 77 | >>> epochs=8, batch_size=200, loss_kwargs={}, domain_loss_kwargs={}, 78 | >>> optimizer_kwargs=dict(momentum=0.9, weight_decay=0.95), 79 | >>> random_state=42, use_gpu=True, verbose=False) 80 | 81 | >>> # first let's test the model without performing the domain adversarial strategy 82 | >>> scores = cross_val_score(clf, X, y) 83 | >>> print(np.mean(scores)) 84 | 0.783 85 | >>> print(scores) 86 | [0.745 0.76 0.775 0.805 0.83 ] 87 | ``` 88 | 89 | Then let's test the model again, but this time with the domain adversarial strategy. 
class LProjector(nn.Module):
    """Local projector layer of BiHDM.

    Mixes the ``n`` sequence positions (electrodes) of every feature into
    ``k`` nodes with a learned weight, applies a Leaky ReLU to add
    non-linearity, and then adds a learned per-feature, per-node bias.

    Parameters
    ----------
    n : int
        Number of electrodes in a stream (input RNN sequence length).
    d : int
        Number of global high level features in each node.
    k : int
        Number of nodes in the projector layer.
    a : float, optional (default=0.01)
        Slope for LeakyReLU.

    Attributes
    ----------
    weight : nn.Parameter of shape (n, k)
        Projection weight, contracted against the electrode axis.
    bias : nn.Parameter of shape (d, k)
        Bias added after the activation.
    """

    def __init__(self, n, d, k, a=0.01):
        super().__init__()
        self.n = n
        self.d = d
        self.k = k
        self.a = a

        self.act_func_ = nn.LeakyReLU(a)

        # Weights and bias (values are overwritten by the initialisers below).
        self.weight = nn.Parameter(torch.randn((n, k)))
        self.bias = nn.Parameter(torch.zeros((d, k)))

        # ``weight`` is laid out as (n inputs, k outputs), i.e. the transpose
        # of nn.Linear's (out, in) layout that torch.nn.init assumes when it
        # derives the fan from a 2-D tensor (fan_in = size(1), fan_out =
        # size(0)).  The true fan-in of this layer is ``n`` (the axis summed
        # over in ``forward``), which torch reports as "fan_out" here, so
        # mode='fan_out' is what scales the weights by the real fan-in and
        # keeps them consistent with the bias bound below.
        nn.init.kaiming_uniform_(self.weight, a=a, mode='fan_out',
                                 nonlinearity='leaky_relu')
        b_bound = 1 / math.sqrt(n)
        nn.init.uniform_(self.bias, -b_bound, b_bound)  # U(-1/sqrt(n), 1/sqrt(n))

    def forward(self, x):
        """Forward pass of LProjector.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape s x n x d, where s is the number of samples.

        Returns
        -------
        g : torch.Tensor
            Tensor of shape s x d x k, where s is the number of samples.
        """
        # Contract the electrode axis n: (n, k) with (s, n, d) -> (s, d, k).
        ws = torch.einsum('nk,snd->sdk', self.weight, x)
        # The bias is added after the activation; (d, k) broadcasts over s.
        return self.act_func_(ws) + self.bias
class GradReverse(Function):
    """Gradient reversal function.

    The forward pass returns its input unchanged (as a view of itself), while
    the backward pass negates the incoming gradient.

    Implementation follows this forum answer:
    https://discuss.pytorch.org/t/solved-reverse-gradients-in-backward-pass/3589/4
    """

    @staticmethod
    def forward(ctx, x):
        # Identity: hand back a view of the input.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the gradient flowing back through this node.
        return grad_output.neg()


class GRL(nn.Module):
    """Module wrapper that applies :class:`GradReverse` to its input."""

    def forward(self, x):
        return GradReverse.apply(x)


class BiHDM(nn.Module):
    """An implementation of the BiHDM model proposed in [1].

    Parameters
    ----------
    lh_stream : list of int
        Indices for selecting the left hemisphere horizontal stream.
    rh_stream : list of int
        Indices for selecting the right hemisphere horizontal stream.
    lv_stream : list of int
        Indices for selecting the left hemisphere vertical stream.
    rv_stream : list of int
        Indices for selecting the right hemisphere vertical stream.
    n_classes : int
        Number of classes for classification.
    d_input : int
        Number of features for each electrode's raw representation.
    d_stream : int, optional (default=32)
        Number of features for each electrode's deep representation (d_l in paper [1]).
    d_pair : int, optional (default=32)
        Number of features for each electrode's deep representation after the pairwise
        operation (d_p1, d_p2, or d_p3 in paper [1]). This is the input size of the
        global RNNs, so it must match what the pairwise operation returns: the built-in
        'subtraction', 'addition' and 'division' keep d_stream features, while
        'inner_product' returns a single feature.
    d_global : int, optional (default=32)
        Number of global high level features (d_g in paper [1]).
    d_out : int, optional (default=16)
        Number of output features (d_o in paper [1]).
    k : int, optional (default=6)
        Number of nodes for global high level features (K in paper [1]).
    a : float, optional (default=0.01)
        Slope for LeakyReLU in high level feature projector.
    pairwise_operation : str or callable, optional (default='subtraction')
        Pairwise operation for the two hemispheres' electrode deep representation streams.
        If string, acceptable operations are 'subtraction', 'addition', 'division', and
        'inner_product'. If callable, it should take two torch.Tensor objects (s_left
        and s_right of sizes L x N x d_stream) and return one torch.Tensor of size
        L x N x d_pair, where L is the number of electrodes in a stream and N is the
        batch size.
    output_domain : bool, optional (default=True)
        If the domain adversarial strategy should be performed. If set to true, an extra
        domain discrimination layer (behind a gradient reversal layer) is added in
        parallel to the class discrimination layer. The domain predictions are returned
        as the second element of a tuple together with the class predictions in the
        forward method.
    rnn_stream_kwargs : dict or None, optional (default=None)
        kwargs to feed into RNNs extracting electrodes' deep features. None means no
        extra kwargs.
    rnn_global_kwargs : dict or None, optional (default=None)
        kwargs to feed into RNNs extracting global high level features. None means no
        extra kwargs.

    Raises
    ------
    ValueError
        If pairwise_operation is neither one of the supported strings nor a callable.

    Notes
    -----
    [1] Y. Li et al., "A Novel Bi-Hemispheric Discrepancy Model for EEG Emotion Recognition,"
    IEEE Trans. Cogn. Dev. Syst., vol. 13, no. 2, pp. 354-367, Jun. 2021.
    """

    def __init__(self, lh_stream, rh_stream, lv_stream, rv_stream, n_classes,
                 d_input, d_stream=32, d_pair=32, d_global=32, d_out=16, k=6, a=0.01,
                 pairwise_operation='subtraction', output_domain=True,
                 rnn_stream_kwargs=None, rnn_global_kwargs=None):
        super().__init__()

        # Fresh dicts per instance: avoids sharing a mutable default between models.
        rnn_stream_kwargs = dict(rnn_stream_kwargs or {})
        rnn_global_kwargs = dict(rnn_global_kwargs or {})

        # Store the inputs as instance variables.
        self.lh_stream = lh_stream
        self.rh_stream = rh_stream
        self.lv_stream = lv_stream
        self.rv_stream = rv_stream
        self.d_input = d_input
        self.d_stream = d_stream
        self.d_pair = d_pair
        self.d_global = d_global
        self.d_out = d_out
        self.k = k
        self.a = a
        self.n_classes = n_classes
        self.pairwise_operation = pairwise_operation
        self.output_domain = output_domain
        self.rnn_stream_kwargs = rnn_stream_kwargs
        self.rnn_global_kwargs = rnn_global_kwargs

        # Define the RNNs for each stream.
        self.rnn_lh_ = nn.RNN(d_input, d_stream, batch_first=False, **rnn_stream_kwargs)
        self.rnn_rh_ = nn.RNN(d_input, d_stream, batch_first=False, **rnn_stream_kwargs)
        self.rnn_lv_ = nn.RNN(d_input, d_stream, batch_first=False, **rnn_stream_kwargs)
        self.rnn_rv_ = nn.RNN(d_input, d_stream, batch_first=False, **rnn_stream_kwargs)

        # Resolve the pairwise operation: a known name or a custom callable.
        builtin_ops = {
            'subtraction': self.pairwise_subtraction,
            'addition': self.pairwise_addition,
            'division': self.pairwise_division,
            'inner_product': self.pairwise_inner,
        }
        if isinstance(pairwise_operation, str):
            if pairwise_operation not in builtin_ops:
                raise ValueError(
                    f"Unknown pairwise_operation {pairwise_operation!r}; "
                    f"expected one of {sorted(builtin_ops)} or a callable."
                )
            self.pair_ = builtin_ops[pairwise_operation]
        elif callable(pairwise_operation):
            self.pair_ = pairwise_operation
        else:
            raise ValueError(
                "pairwise_operation must be a string or a callable, got "
                f"{type(pairwise_operation).__name__}."
            )

        # Define the RNNs for the global representations of the two paired streams.
        self.rnn_hg_ = nn.RNN(d_pair, d_global, batch_first=False, **rnn_global_kwargs)
        self.rnn_vg_ = nn.RNN(d_pair, d_global, batch_first=False, **rnn_global_kwargs)

        # Define the LProjector instances for the two streams. Each projector is
        # sized by the length of the stream it actually consumes.
        self.proj_h_ = LProjector(len(lh_stream), d_global, k, a)
        self.proj_v_ = LProjector(len(lv_stream), d_global, k, a)

        # Define the learnable weight matrices for the final linear layers
        # (re-initialised in init_weights).
        self.map_h_ = nn.Parameter(torch.randn((d_out, d_global)))
        self.map_v_ = nn.Parameter(torch.randn((d_out, d_global)))

        # Define the output layers.
        self.out_class_ = nn.Sequential(
            nn.Linear(d_out * k, n_classes, bias=True),
            nn.LogSoftmax(dim=-1)
        )
        if self.output_domain:
            # The domain head sits behind a gradient reversal layer.
            self.out_domain_ = nn.Sequential(
                GRL(),
                nn.Linear(d_out * k, 2, bias=True),
                nn.LogSoftmax(dim=-1)
            )

        # Initialize the weights.
        with torch.no_grad():
            self.init_weights()

    def init_weights(self):
        """Initialize weights of the model.

        Applies Xavier uniform initialisation (gain=1) to every weight matrix of
        the RNN sub-modules, to the two map matrices, and to the weights of the
        output linear layers. RNN biases, linear biases and the LProjector
        parameters keep the values they received on construction.
        """
        # init RNN weights with xavier uniform
        for module in self.modules():
            if isinstance(module, nn.RNN):
                for name, param in module.named_parameters():
                    if 'weight' in name:
                        nn.init.xavier_uniform_(param)

        # LProjectors were initialised on construction

        # init maps with xavier uniform and gain=1
        nn.init.xavier_uniform_(self.map_h_)
        nn.init.xavier_uniform_(self.map_v_)

        # init output linear weights with xavier uniform and gain=1
        nn.init.xavier_uniform_(self.out_class_[0].weight)
        if self.output_domain:
            # index 1: index 0 of the domain head is the GRL
            nn.init.xavier_uniform_(self.out_domain_[1].weight)

    def pairwise_subtraction(self, sl, sr):
        """Element-wise difference of the left and right representations."""
        return sl - sr

    def pairwise_addition(self, sl, sr):
        """Element-wise sum of the left and right representations."""
        return sl + sr

    def pairwise_division(self, sl, sr):
        """Element-wise quotient of the left and right representations."""
        return sl / sr

    def pairwise_inner(self, sl, sr):
        """Inner product over the feature axis, kept as a trailing axis of size 1."""
        return torch.einsum('lnd,lnd->ln', sl, sr)[:, :, None]

    def forward(self, x):
        """Compute the forward pass of the BiHDM model.

        Parameters
        ----------
        x : torch.Tensor of shape n_sample x n_channels x d_input
            The input tensor.

        Returns
        -------
        torch.Tensor or tuple
            If self.output_domain is False, a tensor of shape (n_sample, n_classes)
            holding the class log-probabilities (LogSoftmax output). Otherwise a tuple
            of size 2 whose first element holds the class log-probabilities and whose
            second element is a tensor of shape (n_sample, 2) holding the domain
            log-probabilities.
        """
        # electrode deep representation (len(stream) x n_sample x d_stream);
        # the RNNs are sequence-first, hence the permute.
        lhs, _ = self.rnn_lh_(x[:, self.lh_stream].permute(1, 0, 2))
        rhs, _ = self.rnn_rh_(x[:, self.rh_stream].permute(1, 0, 2))
        lvs, _ = self.rnn_lv_(x[:, self.lv_stream].permute(1, 0, 2))
        rvs, _ = self.rnn_rv_(x[:, self.rv_stream].permute(1, 0, 2))

        # pairwise operation (len(stream) x n_sample x d_pair); each stream
        # pairs its own left and right representations.
        ph = self.pair_(lhs, rhs)
        pv = self.pair_(lvs, rvs)

        # high level features (len(stream) x n_sample x d_global)
        gh, _ = self.rnn_hg_(ph)
        gv, _ = self.rnn_vg_(pv)

        # project high level features (n_sample x d_global x k)
        gh = self.proj_h_(gh.permute(1, 0, 2))
        gv = self.proj_v_(gv.permute(1, 0, 2))

        # map and summarise (n_sample x d_out x k)
        gh = torch.einsum('og,sgk->sok', self.map_h_, gh)
        gv = torch.einsum('og,sgk->sok', self.map_v_, gv)
        hv = (gh + gv).flatten(start_dim=1)

        if self.output_domain:
            return self.out_class_(hv), self.out_domain_(hv)
        return self.out_class_(hv)
316 | d_stream : int, optional (default=32) 317 | Number of features for each electrode's deep representation (d_l in paper [1]). 318 | d_pair: int , optional (default=32) 319 | Number of features for each electrode's deep representation after the pairwise 320 | operation (d_p1, d_p2, or d_p3 in paper [1]). 321 | d_global: int, optional (default=32) 322 | Number of global high level features (d_g in paper [1]). 323 | d_out: int, optional (default=16) 324 | Number of output features (d_o in paper [1]). 325 | k: int, optional (default=6) 326 | Number of nodes for global high level features (K in paper [1]). 327 | a : float, optional (default=0.01) 328 | Slope for LeakyReLU in high level feature projector. 329 | pairwise_operation : str, optional (default='subtraction') 330 | Operation used to compute pairwise interactions (see BiHDM for details). 331 | rnn_stream_kwargs : dict, optional (default={}) 332 | Keyword arguments for the stream RNN. 333 | rnn_global_kwargs : dict, optional (default={}) 334 | Keyword arguments for the global RNN. 335 | loss : str or nn.Module class - custom loss, optional (default='NLLLoss') 336 | Type of loss function for the classifier. If is a string, it must be a string exactly 337 | equal to the name of a loss function in torch.nn module (e.g., 'MSELoss', 338 | 'CrossEntropyLoss', etc.), as you are importing the loss function. See `torch.nn` for 339 | available loss functions. If is a custom loss, it must be a nn.Module class like NLLLoss. 340 | domain_loss : str or nn.Module class - custom loss, optional (default='NLLLoss') 341 | Type of loss function for the domain discriminator. If is a string, it must be a string 342 | exactly equal to the name of a loss function in torch.nn module (e.g., 'MSELoss', 343 | 'CrossEntropyLoss', etc.), as you are importing the loss function. See `torch.nn` for 344 | available loss functions. If is a custom loss, it must be a nn.Module class like NLLLoss. 
345 | optimizer : str or optim.Optimizer class - custom optimizer, optional (default='SGD') 346 | Type of optimizer to use. If is a string, it must be a string exactly equal to the name 347 | of an optimizer in torch.optim module (e.g., 'SGD', 'Adam', etc.), as you are importing 348 | the optimizer function. See `torch.optim` for available optimizers. If is a custom 349 | optimizer, it must be an optim.Optimizer class like SGD. 350 | lr : float, optional (default=0.003) 351 | Learning rate. 352 | epochs : int, optional (default=10) 353 | Number of epochs. 354 | batch_size : int, optional (default=200) 355 | Batch size for training. 356 | loss_kwargs : dict, optional (default={}) 357 | Keyword arguments for the loss function. 358 | domain_loss_kwargs : dict, optional (default={}) 359 | Keyword arguments for the domain loss function. 360 | optimizer_kwargs : dict, optional (default={'momentum': 0.9, 'weight_decay': 0.95}) 361 | Keyword arguments for the optimizer. 362 | random_state : int, optional (default=42) 363 | Seed to ensure reproducibility. 364 | use_gpu : bool, optional (default=True) 365 | Whether to use GPU acceleration. 366 | verbose : bool, optional (default=True) 367 | Whether to print progress messages. 368 | 369 | Attributes 370 | ---------- 371 | n_channels_ : int 372 | Number of channels in the EEG data. 373 | n_features_in_ : int 374 | Number of input features. 375 | n_features_per_ch_ : int 376 | Number of features per channel. 377 | n_classes_ : int 378 | Number of classes in the target. 379 | le_ : LabelEncoder 380 | LabelEncoder object. 381 | device_ : torch.device 382 | PyTorch device. 383 | lh_stream_ : list of int 384 | Indices of the channels in the left hemisphere horizontal stream. 385 | rh_stream_ : list of int 386 | Indices of the channels in the right hemisphere horizontal stream. 387 | lv_stream_ : list of int 388 | Indices of the channels in the left hemisphere vertical stream. 
389 | rv_stream_ : list of int 390 | Indices of the channels in the right hemisphere vertical stream. 391 | optimizer_ : torch.optim.Optimizer 392 | Optimizer used for training. 393 | loss_fn_ : nn.Module 394 | Loss function used for training. 395 | classes_ : ndarray of shape (n_classes,) 396 | Unique classes in the target variable. 397 | domain_discrimination_ : bool 398 | Either to perform domain adversarial strategy or not, inferred from the inputs to 399 | the fit method. 400 | """ 401 | 402 | def __init__(self, ch_names, lh_chs, rh_chs, lv_chs, rv_chs, 403 | d_stream=32, d_pair=32, d_global=32, d_out=16, 404 | k=6, a=0.01, pairwise_operation='subtraction', 405 | rnn_stream_kwargs={}, rnn_global_kwargs={}, 406 | loss='NLLLoss', domain_loss='NLLLoss', optimizer='SGD', lr=0.003, 407 | epochs=10, batch_size=200, loss_kwargs={}, domain_loss_kwargs={}, 408 | optimizer_kwargs=dict(momentum=0.9, weight_decay=0.95), 409 | random_state=42, use_gpu=True, verbose=True): 410 | 411 | self.ch_names = ch_names 412 | self.lh_chs = lh_chs 413 | self.rh_chs = rh_chs 414 | self.lv_chs = lv_chs 415 | self.rv_chs = rv_chs 416 | 417 | self.d_stream = d_stream 418 | self.d_pair = d_pair 419 | self.d_global = d_global 420 | self.d_out = d_out 421 | self.k = k 422 | self.a = a 423 | self.pairwise_operation = pairwise_operation 424 | 425 | self.rnn_stream_kwargs = rnn_stream_kwargs 426 | self.rnn_global_kwargs = rnn_global_kwargs 427 | 428 | self.loss = loss 429 | self.domain_loss = domain_loss 430 | self.optimizer = optimizer 431 | self.lr = lr 432 | self.epochs = epochs 433 | self.batch_size = batch_size 434 | self.loss_kwargs = loss_kwargs 435 | self.domain_loss_kwargs = domain_loss_kwargs 436 | self.optimizer_kwargs = optimizer_kwargs 437 | 438 | self.random_state = random_state 439 | self.use_gpu = use_gpu 440 | self.verbose = verbose 441 | 442 | # selecting the indices for corresponding streams 443 | self.n_channels_ = len(ch_names) 444 | self.lh_stream_ = 
[list(ch_names).index(ch) for ch in lh_chs] 445 | self.rh_stream_ = [list(ch_names).index(ch) for ch in rh_chs] 446 | self.lv_stream_ = [list(ch_names).index(ch) for ch in lv_chs] 447 | self.rv_stream_ = [list(ch_names).index(ch) for ch in rv_chs] 448 | 449 | if torch.cuda.is_available() and use_gpu==True: 450 | dev = "cuda:0" 451 | else: 452 | dev = "cpu" 453 | self.device_ = torch.device(dev) 454 | 455 | def fit(self, X, y, X_test=None): 456 | ''' 457 | Fit the BiHDMClassifier to the training data. 458 | 459 | Parameters 460 | ---------- 461 | X : array-like of shape (n_samples, n_features) 462 | X will be internally reshaped to (n_samples, n_channels, n_features_per_channel) 463 | by calling numpy.reshape using C-like index order. 464 | y : array-like of shape (n_samples,) 465 | Target variable. 466 | X_test : array-like of shape (n_test_samples, n_features), optional (default=None) 467 | The test data for performing the domain adversarial strategy. If not provided, the 468 | classifier will be trained only with the class discriminator but not the domain 469 | discriminator. X_test will be internally reshaped to (n_test_samples, n_channels, 470 | n_features_per_channel) by calling numpy.reshape using C-like index order. 471 | 472 | Returns 473 | ------- 474 | self : BiHDMClassifier 475 | The trained classifier. 476 | ''' 477 | # Check that X and y have correct shape 478 | X, y = check_X_y(X, y) 479 | if X_test is not None: 480 | self.domain_discrimination_ = True 481 | X_test = check_array(X_test) 482 | else: 483 | self.domain_discrimination_ = False 484 | 485 | # Get dimensions of the input data 486 | self.n_features_in_ = X.shape[1] 487 | self.n_features_per_ch_ = int(self.n_features_in_/self.n_channels_) 488 | assert self.n_features_per_ch_ * self.n_channels_ == self.n_features_in_, \ 489 | f"Number of features ({self.n_features_in_}) could not be equally divided by the channels ({self.n_channels_})." 
490 | self.n_classes_ = np.unique(y).shape[0] 491 | 492 | if not self.random_state is None: 493 | torch.manual_seed(self.random_state) 494 | 495 | # Encode the target variable 496 | self.le_ = LabelEncoder() 497 | self.le_.fit(y) 498 | 499 | # Reshape X and y and cast into torch tensors for training the BiHDM module 500 | X_ = np.reshape(X, [X.shape[0], self.n_channels_, self.n_features_per_ch_], order='C') 501 | X_ = torch.as_tensor(X_, dtype=torch.float).to(self.device_) 502 | y_ = torch.as_tensor(self.le_.transform(y), dtype=torch.int64).to(self.device_) 503 | if self.domain_discrimination_: 504 | X_test_ = np.reshape(X_test, [X_test.shape[0], self.n_channels_, self.n_features_per_ch_], order='C') 505 | X_test_ = torch.as_tensor(X_test_, dtype=torch.float).to(self.device_) 506 | n_tests = X_test_.size()[0] 507 | 508 | # Construct BiHDM 509 | self.bihdm_ = BiHDM(self.lh_stream_, self.rh_stream_, self.lv_stream_, self.rv_stream_, 510 | n_classes=self.n_classes_, d_input=self.n_features_per_ch_, 511 | d_stream=self.d_stream, d_pair=self.d_pair, 512 | d_global=self.d_global, d_out=self.d_out, k=self.k, a=self.a, 513 | pairwise_operation=self.pairwise_operation, output_domain=self.domain_discrimination_, 514 | rnn_stream_kwargs=self.rnn_stream_kwargs, 515 | rnn_global_kwargs=self.rnn_global_kwargs) 516 | self.bihdm_.to(self.device_) 517 | 518 | # Setup training steps 519 | dataset = TensorDataset(X_, y_) 520 | loader = DataLoader(dataset, batch_size=self.batch_size) 521 | loss_fn = getattr(import_module('torch.nn'), self.loss)(**self.loss_kwargs) \ 522 | if type(self.loss) == str else self.loss(**self.loss_kwargs) 523 | loss_domain = getattr(import_module('torch.nn'), self.domain_loss)(**self.domain_loss_kwargs) \ 524 | if type(self.domain_loss) == str else self.domain_loss(**self.domain_loss_kwargs) 525 | optimizer = getattr(import_module('torch.optim'), self.optimizer) if type(self.optimizer) == str else self.optimizer 526 | optimizer = 
optimizer(self.bihdm_.parameters(), lr=self.lr, **self.optimizer_kwargs) 527 | 528 | # Iterate through epochs to train BiHDM 529 | self.bihdm_.train(True) 530 | for ep in range(self.epochs): 531 | running_loss = 0. 532 | 533 | for i, (batch, labels) in enumerate(loader): 534 | optimizer.zero_grad() 535 | if self.domain_discrimination_: 536 | outputs, domains_train = self.bihdm_.forward(batch) 537 | _, domains_test = self.bihdm_.forward(X_test_) 538 | domains = torch.cat([domains_train, domains_test], dim=0) 539 | n_trains = labels.size()[0] 540 | domain_labels = torch.zeros(n_trains + n_tests, dtype=torch.int64).to(self.device_) 541 | domain_labels[n_trains:] = 1 542 | loss = loss_fn(outputs, labels) + loss_domain(domains, domain_labels) 543 | else: 544 | outputs = self.bihdm_.forward(batch) 545 | loss = loss_fn(outputs, labels) 546 | loss.backward() 547 | optimizer.step() 548 | running_loss += loss.item() 549 | avg_loss = running_loss / len(loader) 550 | 551 | if self.verbose: 552 | print(f'Epoch {ep}: loss_train={avg_loss}') 553 | self.bihdm_.train(False) 554 | 555 | return self 556 | 557 | def predict(self, X): 558 | """Predict the class labels for the given input samples. 559 | 560 | Parameters 561 | ---------- 562 | X : array-like of shape (n_samples, n_features) 563 | The input samples. 564 | 565 | Returns 566 | ------- 567 | y_pred : array-like of shape (n_samples,) 568 | The predicted class labels. 
569 | """ 570 | 571 | # Check if fit has been called 572 | check_is_fitted(self) 573 | 574 | # Input validation 575 | X = check_array(X) 576 | X = np.reshape(X, [X.shape[0], self.n_channels_, self.n_features_per_ch_], order='C') 577 | X = torch.as_tensor(X, dtype=torch.float).to(self.device_) 578 | 579 | pred = self.bihdm_(X) 580 | pred = pred[0] if self.domain_discrimination_ else pred 581 | return self.le_.inverse_transform(torch.argmax(pred, dim=-1).to('cpu').detach().numpy()) 582 | 583 | def predict_proba(self, X): 584 | """Predict class probabilities for the given input samples. 585 | 586 | Parameters 587 | ---------- 588 | X : array-like of shape (n_samples, n_features) 589 | The input samples. 590 | 591 | Returns 592 | ------- 593 | y_proba : array-like of shape (n_samples, n_classes) 594 | The class probabilities of the input samples. 595 | """ 596 | # Check if fit has been called 597 | check_is_fitted(self) 598 | 599 | # Input validation 600 | X = check_array(X) 601 | X = np.reshape(X, [X.shape[0], self.n_channels_, self.n_features_per_ch_], order='C') 602 | X = torch.as_tensor(X, dtype=torch.float).to(self.device_) 603 | 604 | pred = self.bihdm_(X) 605 | pred = pred[0] if self.domain_discrimination_ else pred 606 | return torch.exp(pred).to('cpu').detach().numpy() 607 | --------------------------------------------------------------------------------