├── .gitignore ├── LICENSE ├── README.md ├── _local ├── cleanup.sh ├── pull_git.sh └── push_git.sh ├── examples ├── butterworth-filter.ipynb ├── load-openbmi.ipynb └── multitask-learning-model.ipynb ├── requirements.txt └── torchsignal ├── __init__.py ├── cross_decomposition ├── CCA.py ├── TRCA.py ├── __init__.py └── reference_frequencies.py ├── datasets ├── __init__.py ├── dataset.py ├── generate.py ├── hsssvep.py ├── multiplesubjects.py ├── openbmi.py └── utils.py ├── filter ├── __init__.py ├── butterworth.py └── channels.py ├── model ├── CompactEEGNet.py ├── MIEEGNet.py ├── MultitaskSSVEP.py ├── MultitaskSSVEPClassifier.py ├── Performer.py ├── README.md ├── WaveNet.py ├── __init__.py └── common │ ├── __init__.py │ ├── conv.py │ └── utils.py ├── trainer └── multitask.py └── transform ├── __init__.py ├── fft.py ├── segment.py └── spectrogram.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # extras 132 | .DS_Store 133 | /_data 134 | /_tmp_models 135 | /_dev 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Hong Jing (Jingles) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchsignal: a signal processing library for PyTorch 2 | 3 | The torchsignal package provides datasets, model architectures, and common signal processing functions for preparing signals before they are fed to PyTorch models. It is a toolbox for data manipulation and transformation in signal processing. 4 | 5 | ## Installation 6 | 7 | This package has not yet been released on PyPI. Clone the repository with Git and install the dependencies: 8 | 9 | ``` 10 | git clone https://github.com/jinglescode/torchsignal.git 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ### Dependencies 15 | 16 | See [requirements.txt](https://github.com/jinglescode/torchsignal/tree/master/requirements.txt). 17 | 18 | ## Usage 19 | 20 | See the [examples folder](https://github.com/jinglescode/torchsignal/tree/master/examples). 
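As a quick start, the sketch below generates a synthetic multi-frequency signal and bandpass-filters it with the utilities in `torchsignal.datasets.generate` and `torchsignal.filter.butterworth` (the function signatures come from this repository; the frequencies, sampling rate, and filter band are illustrative only):

```python
from torchsignal.datasets.generate import generate_signal
from torchsignal.filter.butterworth import butter_bandpass_filter_signal_1d

# 4-second synthetic signal sampled at 256 Hz, containing 6 Hz and 20 Hz components plus noise
signal = generate_signal(
    length_seconds=4,
    sampling_rate=256,
    frequencies_list=[6, 20],
    add_noise=0.5,
    plot=False,
)

# zero-phase Butterworth bandpass keeping only the band around 6 Hz
filtered = butter_bandpass_filter_signal_1d(
    signal, lowcut=4, highcut=8, sample_rate=256, order=4
)
print(signal.shape, filtered.shape)  # (1024,) (1024,)
```

The notebooks in the examples folder (e.g. `butterworth-filter.ipynb`) walk through the same workflow in more detail.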
21 | 22 | ## Models 23 | 24 | These are the [models](https://github.com/jinglescode/torchsignal/tree/master/torchsignal/model) available in this repo: 25 | - Multitask Model 26 | - EEGNet 27 | - Performer 28 | - WaveNet 29 | 30 | See [list of models and usage](https://github.com/jinglescode/torchsignal/tree/master/torchsignal/model). 31 | 32 | ## API Reference 33 | 34 | Work in progress. Meanwhile, see the [examples folder](https://github.com/jinglescode/torchsignal/tree/master/examples). 35 | 36 | ## Tutorials 37 | 38 | We aim to bridge the gap for anyone who is new to signal processing and wants to get started. 39 | 40 | - [What Are Signals?: An Introduction to Signals](https://github.com/jinglescode/torchsignal/wiki/What-Are-Signals%3F) 41 | - [Generate Signals](https://github.com/jinglescode/torchsignal/wiki/Generate-Signals) 42 | - [Butterworth Filter](https://github.com/jinglescode/torchsignal/wiki/Butterworth-Filter) 43 | 44 | ## Contributing Guidelines 45 | 46 | Please let us know if you encounter a bug by filing an [issue](https://github.com/jinglescode/torchsignal/issues). 47 | 48 | We are seeking collaborators to contribute new features, utility functions, bug fixes, and documentation. Currently, I am working on this alone. If you work on signal processing or brain-computer interfaces and are keen to build a high-quality package that applies PyTorch to the signal processing domain, reach out to me via [various channels](https://jinglescode.github.io/). 49 | 50 | ## Disclaimer on Datasets 51 | 52 | We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have a license to use them. It is your responsibility to determine whether you have permission to use a dataset under its license. 53 | 54 | If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community! 55 | -------------------------------------------------------------------------------- /_local/cleanup.sh: -------------------------------------------------------------------------------- 1 | find . -name '.DS_Store' -type f -delete 2 | find . -name 'Icon?' -type f -delete 3 | find . -name 'package-lock.json' -type f -delete 4 | find . 
-name "__pycache__" -type d -exec rm -r "{}" \; 5 | -------------------------------------------------------------------------------- /_local/pull_git.sh: -------------------------------------------------------------------------------- 1 | git pull https://github.com/jinglescode/torchsignal.git master -------------------------------------------------------------------------------- /_local/push_git.sh: -------------------------------------------------------------------------------- 1 | git add -A 2 | git commit -m "$1" 3 | git push https://github.com/jinglescode/torchsignal.git master 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | numpy 3 | scipy 4 | matplotlib 5 | sklearn -------------------------------------------------------------------------------- /torchsignal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinglescode/torchsignal/6172bc2b18eeafa9464cfba678e9c02ea4ed5e2a/torchsignal/__init__.py -------------------------------------------------------------------------------- /torchsignal/cross_decomposition/CCA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from sklearn.cross_decomposition import CCA 4 | from sklearn.metrics import confusion_matrix 5 | import functools 6 | 7 | 8 | def find_correlation_cca_method1(signal, reference_signals, n_components=2): 9 | r""" 10 | Perform canonical correlation analysis (CCA) 11 | Reference: https://github.com/aaravindravi/Brain-computer-interfaces/blob/master/notebook_12_class_cca.ipynb 12 | 13 | Args: 14 | signal : ndarray, shape (channel,time) 15 | Input signal in time domain 16 | reference_signals : ndarray, shape (len(flick_freq),2*num_harmonics,time) 17 | Required sinusoidal reference templates corresponding to the flicker frequency for SSVEP classification 18 | n_components : int, default: 2 19 | number of components to keep (for sklearn.cross_decomposition.CCA) 20 | Returns: 21 | result : array, size: len(flick_freq) 22 | Probability for each reference signals 23 | Dependencies: 24 | CCA : sklearn.cross_decomposition.CCA 25 | np : numpy package 26 | """ 27 | 28 | cca = CCA(n_components) 29 | corr = np.zeros(n_components) 30 | result = np.zeros(reference_signals.shape[0]) 31 | for freq_idx in range(0, reference_signals.shape[0]): 32 | cca_x = signal.T 33 | cca_y = np.squeeze(reference_signals[freq_idx, :, :]).T 34 | cca.fit(cca_x, cca_y) 35 | a, b = cca.transform(cca_x, cca_y) 36 | for ind_val in range(0, n_components): 37 | corr[ind_val] = np.corrcoef(a[:, ind_val], b[:, ind_val])[0, 1] 38 | result[freq_idx] = np.max(corr) 39 | return result 40 | 41 | 42 | def calculate_cca(dat_x, dat_y, time_axis=-2): 43 | r""" 44 | Calculate the Canonical Correlation Analysis (CCA). 45 | This method calculates the canonical correlation coefficient and 46 | corresponding weights which maximize a correlation coefficient 47 | between linear combinations of the two specified multivariable 48 | signals. 49 | Reference: https://github.com/venthur/wyrm/blob/master/wyrm/processing.py 50 | Reference: http://en.wikipedia.org/wiki/Canonical_correlation 51 | 52 | Args: 53 | dat_x : continuous Data object 54 | these data should have the same length on the time axis. 
55 | dat_y : continuous Data object 56 | these data should have the same length on the time axis. 57 | time_axis : int, optional 58 | the index of the time axis in ``dat_x`` and ``dat_y``. 59 | Returns: 60 | rho : float 61 | the canonical correlation coefficient. 62 | w_x, w_y : 1d array 63 | the weights for mapping from the specified multivariable signals 64 | to canonical variables. 65 | Raises: 66 | AssertionError : 67 | If: 68 | * ``dat_x`` and ``dat_y`` is not continuous Data object 69 | * the length of ``dat_x`` and ``dat_y`` is different on the 70 | ``time_axis`` 71 | Dependencies: 72 | functools : functools package 73 | np : numpy package 74 | """ 75 | 76 | assert (len(dat_x.data.shape) == len(dat_y.data.shape) == 2 and 77 | dat_x.data.shape[time_axis] == dat_y.data.shape[time_axis]) 78 | 79 | if time_axis == 0 or time_axis == -2: 80 | x = dat_x.copy() 81 | y = dat_y.copy() 82 | else: 83 | x = dat_x.T.copy() 84 | y = dat_y.T.copy() 85 | 86 | # calculate covariances and it's inverses 87 | x -= x.mean(axis=0) 88 | y -= y.mean(axis=0) 89 | n = x.shape[0] 90 | c_xx = np.dot(x.T, x) / n 91 | c_yy = np.dot(y.T, y) / n 92 | c_xy = np.dot(x.T, y) / n 93 | c_yx = np.dot(y.T, x) / n 94 | ic_xx = np.linalg.pinv(c_xx) 95 | ic_yy = np.linalg.pinv(c_yy) 96 | # calculate w_x 97 | w, v = np.linalg.eig(functools.reduce(np.dot, [ic_xx, c_xy, ic_yy, c_yx])) 98 | w_x = v[:, np.argmax(w)].real 99 | w_x = w_x / np.sqrt(functools.reduce(np.dot, [w_x.T, c_xx, w_x])) 100 | # calculate w_y 101 | w, v = np.linalg.eig(functools.reduce(np.dot, [ic_yy, c_yx, ic_xx, c_xy])) 102 | w_y = v[:, np.argmax(w)].real 103 | w_y = w_y / np.sqrt(functools.reduce(np.dot, [w_y.T, c_yy, w_y])) 104 | # calculate rho 105 | rho = abs(functools.reduce(np.dot, [w_x.T, c_xy, w_y])) 106 | return rho, w_x, w_y 107 | 108 | 109 | def find_correlation_cca_method2(signal, reference_signals): 110 | r""" 111 | Perform canonical correlation analysis (CCA) 112 | 113 | Args: 114 | signal : ndarray, shape (channel,time) 115 | Input signal in time domain 116 | reference_signals : ndarray, shape (len(flick_freq),2*num_harmonics,time) 117 | Required sinusoidal reference templates corresponding to the flicker frequency for SSVEP classification 118 | Returns: 119 | result : array, size: len(flick_freq) 120 | Probability for each reference signals 121 | Dependencies: 122 | np : numpy package 123 | calculate_cca : function 124 | """ 125 | 126 | result = np.zeros(reference_signals.shape[0]) 127 | for freq_idx in range(0, reference_signals.shape[0]): 128 | dat_y = np.squeeze(reference_signals[freq_idx, :, :]).T 129 | rho, w_x, w_y = calculate_cca(signal.T, dat_y) 130 | result[freq_idx] = rho 131 | return result 132 | 133 | 134 | def perform_cca(signal, reference_frequencies, labels=None): 135 | r""" 136 | Perform canonical correlation analysis (CCA) 137 | 138 | Args: 139 | signal : ndarray, shape (trial,channel,time) or (trial,channel,segment,time) 140 | Input signal in time domain 141 | reference_frequencies : ndarray, shape (len(flick_freq),2*num_harmonics,time) 142 | Required sinusoidal reference templates corresponding to the flicker frequency for SSVEP classification 143 | labels : ndarray shape (classes,) 144 | True labels of `signal`. 
Index of the classes must be match the sequence of `reference_frequencies` 145 | Returns: 146 | predicted_class : ndarray, size: (classes,) 147 | Predicted classes according to reference_frequencies 148 | accuracy : double 149 | If `labels` are given, `accuracy` denote classification accuracy 150 | Dependencies: 151 | confusion_matrix : sklearn.metrics.confusion_matrix 152 | find_correlation_cca_method1 : function 153 | find_correlation_cca_method2 : function 154 | """ 155 | 156 | assert (len(signal.shape) == 3 or len(signal.shape) == 4), "signal shape must be 3 or 4 dimension" 157 | 158 | actual_class = [] 159 | predicted_class = [] 160 | accuracy = None 161 | 162 | for trial in range(0, signal.shape[0]): 163 | 164 | if len(signal.shape) == 3: 165 | if labels is not None: 166 | actual_class.append(labels[trial]) 167 | tmp_signal = signal[trial, :, :] 168 | 169 | result = find_correlation_cca_method2(tmp_signal, reference_frequencies) 170 | predicted_class.append(np.argmax(result)) 171 | 172 | if len(signal.shape) == 4: 173 | for segment in range(0, signal.shape[2]): 174 | 175 | if labels is not None: 176 | actual_class.append(labels[trial]) 177 | tmp_signal = signal[trial, :, segment, :] 178 | 179 | result = find_correlation_cca_method2(tmp_signal, reference_frequencies) 180 | predicted_class.append(np.argmax(result)) 181 | 182 | actual_class = np.array(actual_class) 183 | predicted_class = np.array(predicted_class) 184 | 185 | if labels is not None: 186 | # creating a confusion matrix of true versus predicted classification labels 187 | c_mat = confusion_matrix(actual_class, predicted_class) 188 | # computing the accuracy from the confusion matrix 189 | accuracy = np.divide(np.trace(c_mat), np.sum(np.sum(c_mat))) 190 | 191 | return predicted_class, accuracy 192 | -------------------------------------------------------------------------------- /torchsignal/cross_decomposition/TRCA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Work in progress 3 | Taken from https://github.com/iuype/TRCA-SSVEP/blob/master/Main.py 4 | """ 5 | 6 | # ''' 7 | # Author: 8 | # Yu Pei, 1666424499@qq.com 9 | # Versions: 10 | # v1.0: 2019-12-12, 11 | # V1.1: 2019-12-15, fix the bug : det(Q) = 0 12 | # ''' 13 | # import os 14 | # import sys 15 | # import argparse 16 | # from scipy.io import loadmat 17 | # import glob 18 | # import numpy as np 19 | # from sklearn.model_selection import train_test_split 20 | # from scipy import signal 21 | # ''' 22 | # (1).高通滤波 23 | # #这里假设采样频率为1000hz,信号本身最大的频率为500hz,要滤除10hz以下频率成分,即截至频率为10hz,则wn=2*10/1000=0.02 24 | # from scipy import signal 25 | # b, a = signal.butter(8, 0.02, 'highpass') 26 | # filtedData = signal.filtfilt(b, a, data)#data为要过滤的信号 27 | # (2).低通滤波 28 | # #这里假设采样频率为1000hz,信号本身最大的频率为500hz,要滤除10hz以上频率成分,即截至频率为10hz,则wn=2*10/1000=0.02 29 | # from scipy import signal 30 | # b, a = signal.butter(8, 0.02, 'lowpass') 31 | # filtedData = signal.filtfilt(b, a, data) #data为要过滤的信号 32 | 33 | # (3).带通滤波 34 | # #这里假设采样频率为1000hz,信号本身最大的频率为500hz,要滤除10hz以下和400hz以上频率成分,即截至频率为10hz和400hz,则wn1=2*10/1000=0.02,wn2=2*400/1000=0.8。Wn=[0.02,0.8] 35 | # from scipy import signal 36 | # b, a = signal.butter(8, [0.02,0.8], 'bandpass') 37 | # filtedData = signal.filtfilt(b, a, data) #data为要过滤的信号 38 | # ———————————————— 39 | # 版权声明:本文为CSDN博主「John-Cao」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。 40 | # 原文链接:https://blog.csdn.net/weixin_37996604/article/details/82864680 41 | # ''' 42 | 43 | # class TRCA(): 44 | # def __init__(self, opt): 45 | # self.opt 
= opt 46 | # self.channels = ["POz", "PO3", "PO4", "PO5", "PO6", "Oz", "O1", "O2"] 47 | # self.sample_rate = 1000 # 1000Hz 48 | # self.downsample_rate= 250 # 250Hz 49 | 50 | # self.latency = 0.14 # [0.14s, 0.14s + d] 51 | # self.data_length = self.opt.data_length # d , data length 52 | 53 | # self.traindata = None 54 | # self.trainlabel = None 55 | # self.testdata = None 56 | # self.testlabel = None 57 | 58 | # self.Nm = self.opt.Nm 59 | # self.Nf = self.opt.Nf 60 | # self.Nc = len(self.channels) 61 | # self.X_hat = np.zeros((self.Nm, self.Nf, self.Nc, int(self.downsample_rate * self.data_length))) # (Nm, Nf, Nc, data_length) 62 | 63 | # self.W = np.zeros((self.Nm, self.Nf, len(self.channels), 1)) # (Nm, Nf, Nc, 1) 64 | 65 | # self.w1 = [2*(m+1)*8/self.downsample_rate for m in range(self.Nm)] 66 | # self.w2 = [2 * 90 / self.downsample_rate for m in range(self.Nm)] 67 | 68 | 69 | # def load_data(self, dataroot = None): 70 | # if dataroot is None: 71 | # print("dataroot error -->: ",dataroot) 72 | 73 | # datapath = glob.glob(os.path.join(dataroot,"xxrAR","*")) # ['.\\datasets\\xxrAR\\EEG.mat'] 74 | 75 | # oEEG = loadmat(datapath[0]) 76 | 77 | # """ 78 | # data 数据格式(trials, filter_bank, channals, timep_oints) --> (160, Nm, Nf, 0.5s * 采样率) 79 | # """ 80 | # self.traindata, self.testdata, self.trainlabel, self.testlabel = self.segment(oEEG) 81 | 82 | # def segment(self , oEEG): 83 | # EEG = oEEG["EEG"] 84 | # data = EEG["data"][0, 0] 85 | # event = EEG["event"][0, 0] 86 | 87 | # chanlocs = EEG["chanlocs"][0,0] 88 | 89 | # channels_idx = [] 90 | 91 | # for i in range(chanlocs.shape[0]): 92 | # if chanlocs[i,0][0][0] in self.channels: 93 | # channels_idx.append(i) 94 | 95 | # all_data = np.zeros( 96 | # ( 97 | # 160, # 样本个数 98 | # self.Nc, # 通道数 99 | # int(self.downsample_rate*self.data_length) # 数据长度 100 | # ) 101 | # ) 102 | 103 | # all_label = np.zeros((160, 1)) # [0,1,2,...,Nf] 取值范围 104 | 105 | # for idx in range(event.shape[0]): 106 | # lb = int(event["type"][idx, 0][0]) - 1 107 | # lat = event["latency"][idx, 0][0][0] 108 | 109 | # # 原始数据为 1000hz 采样 , 根据原文,需要降采样到 250hz 。 110 | # all_data[idx] = data[channels_idx, int(self.latency* self.sample_rate + lat ) : int(self.latency* self.sample_rate + lat + self.data_length * self.sample_rate )][:,::4] 111 | # all_label[idx,0] = lb 112 | 113 | # all_data = self.filerbank(all_data) 114 | 115 | # # 9 : 1 分割训练集与测试集 116 | # train_X, test_X, train_y, test_y = train_test_split(all_data, 117 | # all_label, 118 | # test_size = 0.2, 119 | # random_state = 0) 120 | # return train_X, test_X, train_y, test_y 121 | 122 | # ''' 123 | # 这里假设采样频率为1000hz,信号本身最大的频率为500hz,要滤除10hz以下和400hz以上频率成分,即截至频率为10hz和400hz,则wn1=2*10/1000=0.02,wn2=2*400/1000=0.8。Wn=[0.02,0.8] 124 | # b, a = signal.butter(8, [0.02,0.8], 'bandpass') 125 | # filtedData = signal.filtfilt(b, a, data) #data为要过滤的信号 126 | # ''' 127 | # def filerbank(self, data): 128 | # data_filterbank = np.zeros((data.shape[0], self.Nm, len(self.channels), data.shape[2])) 129 | 130 | # for i in range(self.Nm): 131 | # # 8 是滤波器的阶数, 不确定用哪个。。 132 | # # print([self.w1[i], self.w2[i]]) 133 | # b, a = signal.butter(self.opt.filter_order, [self.w1[i], self.w2[i]], 'bandpass') 134 | # data_filterbank[:, i, :, :] = signal.filtfilt(b, a, data, axis= -1) 135 | 136 | # return data_filterbank 137 | 138 | # def Cov(self,X,Y): 139 | 140 | # X = X.reshape(-1, 1) 141 | # Y = Y.reshape(-1, 1) 142 | 143 | # # print(X.shape, Y.shape) 144 | 145 | # X_hat = np.mean(X) 146 | # Y_hat = np.mean(Y) 147 | 148 | # X = X - X_hat 149 | # Y = Y - 
Y_hat 150 | 151 | # ret = np.dot(X.T ,Y) 152 | # ret /= (X.shape[0]) 153 | 154 | # return ret 155 | 156 | 157 | # def fit(self): 158 | 159 | # S = np.zeros((self.Nm, self.Nf, self.Nc, self.Nc)) 160 | # Q = np.zeros((self.Nm, self.Nf, self.Nc, self.Nc)) 161 | 162 | # # S 163 | # for m in range(self.Nm): 164 | # for n in range(self.Nf): 165 | # idxs = [] # stimulus n 的索引 166 | # for i in range(self.traindata.shape[0]): 167 | # if self.trainlabel[i, 0] == n: 168 | # idxs.append(i) 169 | # for j1 in range(self.Nc): 170 | # for j2 in range(self.Nc): 171 | # for h1 in idxs: 172 | # for h2 in idxs: 173 | # if h1 != h2: 174 | # S[m, n, j1, j2] += self.Cov(self.traindata[h1, m, j1, :], self.traindata[h2, m, j2, :]) 175 | # # print(S[m,n]) # 检查 S是对称的,没有问题 176 | 177 | # # Q 178 | # for m in range(self.Nm): 179 | # for n in range(self.Nf): 180 | # idxs = [] # stimulus n 的索引 181 | # for i in range(self.traindata.shape[0]): 182 | # if self.trainlabel[i, 0] == n: 183 | # idxs.append(i) 184 | # for h in idxs: 185 | # for j1 in range(self.Nc): 186 | # for j2 in range(self.Nc): 187 | # Q[m, n, j1, j2] += self.Cov(self.traindata[h, m, j1, :], self.traindata[h, m, j2, :]) 188 | # Q[m, n] /= len(idxs) 189 | # # print(Q[m, n]) # 发现bug det(Q) = 0 ... 我日 190 | 191 | # for m in range(self.Nm): 192 | # for n in range(self.Nf): 193 | 194 | # e_vals, e_vecs = np.linalg.eig(np.linalg.inv(Q[m, n]).dot(S[m, n])) 195 | 196 | # max_e_vals_idx = np.argmax(e_vals) 197 | 198 | # self.W[m, n, :, 0] = e_vecs[:, max_e_vals_idx] 199 | 200 | # # calculate hat 201 | # for m in range(self.Nm): 202 | # for n in range(self.Nf): 203 | # idxs = [] # stimulus n 的索引 204 | # for i in range(self.traindata.shape[0]): 205 | # if self.trainlabel[i, 0] == n: 206 | # idxs.append(i) 207 | 208 | # for h in idxs: 209 | 210 | # self.X_hat[m,n] += self.traindata[h, m] # (8, 125) 211 | 212 | # self.X_hat[m, n] /= len(idxs) 213 | 214 | 215 | # tot = 0 216 | # tot_correct = 0 217 | 218 | # for i in range(self.testdata.shape[0]): 219 | # pre_lb , lb = self.inference(self.testdata[i]),self.testlabel[i,0] 220 | # if pre_lb == lb: 221 | # tot_correct += 1 222 | # tot += 1 223 | 224 | # print(tot_correct / tot) 225 | 226 | 227 | # tot = 0 228 | # tot_correct = 0 229 | 230 | # for i in range(self.traindata.shape[0]): 231 | # pre_lb , lb = self.inference(self.traindata[i]),self.trainlabel[i,0] 232 | # if pre_lb == lb: 233 | # tot_correct += 1 234 | # tot += 1 235 | 236 | # print(tot_correct / tot) 237 | 238 | 239 | # def inference(self, X): # (Nm ,Nc, data_length) 240 | # r = np.zeros((self.Nm, self.Nf)) 241 | 242 | # for m in range(self.Nm): 243 | # for n in range(self.Nf): 244 | # r[m, n] = self.pearson_corr_1D(X[m].T.dot(self.W[m, n]), self.X_hat[m, n].T.dot(self.W[m, n])) 245 | 246 | # Pn = np.zeros(self.Nf) 247 | # for n in range(self.Nf): 248 | # for m in range(self.Nm): 249 | # Pn[n] += ((m+1) ** (-1.25) + 0.25 ) * (r[m, n] ** 2) 250 | 251 | # pre_label = np.argmax(Pn) 252 | 253 | # return pre_label 254 | 255 | # def pearson_corr_1D(self, a, b): 256 | # a = a.reshape(-1) 257 | # b = b.reshape(-1) 258 | # ret = self.Cov(a,b) / (np.std(a) * np.std(b)) 259 | # return ret 260 | 261 | 262 | # def pearson_corr_2D(self, a, b): 263 | # """ 264 | # todo 265 | # 2维皮尔逊相关系数 266 | # 两个变量之间的皮尔逊相关系数定义为两个变量之间的协方差和标准差的商 267 | # """ 268 | # return 0.5 269 | 270 | # def __del__(self): 271 | # pass 272 | 273 | 274 | 275 | 276 | # if __name__ == '__main__': 277 | # parser = argparse.ArgumentParser() 278 | # parser.add_argument("--epochs", type=int, default=100, 
help="number of epochs") 279 | 280 | # parser.add_argument("--dataroot", type=str, default=os.path.join(".", "datasets"), help="the folder of data") 281 | # parser.add_argument("--filter_order", type=int, default=8, help="order of filter") 282 | # parser.add_argument("--Nm", type=int, default = 7, help="number of bank") 283 | # parser.add_argument("--data_length", type=float, default=0.5, help="task time points") 284 | # parser.add_argument("--Nf", type=int, default=8, help="number of stimulus") 285 | 286 | # opt = parser.parse_args() 287 | # print(opt) 288 | 289 | # trca = TRCA(opt) 290 | # trca.load_data(opt.dataroot) 291 | # trca.fit() 292 | 293 | # print("done!") -------------------------------------------------------------------------------- /torchsignal/cross_decomposition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinglescode/torchsignal/6172bc2b18eeafa9464cfba678e9c02ea4ed5e2a/torchsignal/cross_decomposition/__init__.py -------------------------------------------------------------------------------- /torchsignal/cross_decomposition/reference_frequencies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_reference_signals(target_frequencies, duration, sample_rate, num_harmonics=2): 5 | r""" 6 | Generating a single sinusoidal template for SSVEP classification 7 | 8 | Args: 9 | target_frequencies : array 10 | Frequencies for SSVEP classification 11 | duration : int 12 | Window/segment length in time samples 13 | sample_rate : int 14 | Sampling frequency 15 | num_harmonics : int, default: 2 16 | Generate till n-th harmonics 17 | Returns: 18 | reference_signals : ndarray, shape (len(flick_freq),4,time) 19 | Reference frequency signals 20 | Example: 21 | Refer to `generate_reference_signals()` 22 | Dependencies: 23 | np : numpy package 24 | """ 25 | 26 | reference_signals = [] 27 | t = np.arange(0, (duration/sample_rate), step=1.0/sample_rate) 28 | 29 | for i in range(1, num_harmonics+1): 30 | j = i*2 31 | reference_signals.append(np.sin(np.pi*j*target_frequencies*t)) 32 | reference_signals.append(np.cos(np.pi*j*target_frequencies*t)) 33 | 34 | reference_signals = np.array(reference_signals) 35 | return reference_signals 36 | 37 | 38 | def generate_reference_signals(flick_freq, duration, sample_rate, num_harmonics=2): 39 | r""" 40 | Generating the required sinusoidal templates for SSVEP classification 41 | 42 | Args: 43 | flick_freq : array 44 | Frequencies for SSVEP classification 45 | duration : int 46 | Window/segment length in time samples 47 | sample_rate : int 48 | Sampling frequency 49 | num_harmonics : int 50 | Generate till n-th harmonics 51 | Returns: 52 | reference_signals : ndarray, shape (len(flick_freq),2*num_harmonics,time) 53 | Reference frequency signals 54 | Example: 55 | reference_frequencies = generate_reference_signals( 56 | [5,7.5,10,12], duration=4000, sample_rate=1000, num_harmonics=3) 57 | Dependencies: 58 | np : numpy package 59 | get_reference_frequencies : function 60 | """ 61 | 62 | reference_frequencies = [] 63 | for fr in range(0, len(flick_freq)): 64 | ref = get_reference_signals(flick_freq[fr], duration, sample_rate, num_harmonics) 65 | reference_frequencies.append(ref) 66 | reference_frequencies = np.array(reference_frequencies, dtype='float32') 67 | 68 | return reference_frequencies -------------------------------------------------------------------------------- 
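Putting the two cross_decomposition modules above together, the following is a minimal sketch of SSVEP classification with CCA (the shapes follow the docstrings of `generate_reference_signals` and `perform_cca`; the EEG here is random noise, so the reported accuracy is only illustrative and should sit near chance level):

```python
import numpy as np
from torchsignal.cross_decomposition.reference_frequencies import generate_reference_signals
from torchsignal.cross_decomposition.CCA import perform_cca

sample_rate = 250
duration = 1000                      # window length in time samples (4 s at 250 Hz)
flick_freq = [5, 7.5, 10, 12]        # stimulus frequencies in Hz

# sinusoidal templates, shape (len(flick_freq), 2*num_harmonics, time)
reference_frequencies = generate_reference_signals(
    flick_freq, duration=duration, sample_rate=sample_rate, num_harmonics=3)

# stand-in EEG: 20 trials, 8 channels, 1000 time samples, with random labels
signal = np.random.randn(20, 8, duration)
labels = np.random.randint(0, len(flick_freq), size=20)

predicted_class, accuracy = perform_cca(signal, reference_frequencies, labels=labels)
print(predicted_class.shape, accuracy)   # (20,) and roughly chance-level accuracy
```

With real SSVEP recordings, `duration` and `sample_rate` should match the segmented signal, and the label indices must follow the same ordering as `flick_freq`.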
/torchsignal/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .openbmi import OPENBMI 2 | 3 | __all__ = ( 4 | "OPENBMI", 5 | ) -------------------------------------------------------------------------------- /torchsignal/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | 4 | 5 | class PyTorchDataset(Dataset): 6 | def __init__(self, data, targets): 7 | self.data = data 8 | self.data = self.data.astype(np.float32) 9 | self.targets = targets 10 | self.channel_names = None 11 | 12 | def __getitem__(self, index): 13 | return self.data[index], self.targets[index] 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def set_data_targets(self, data: [] = None, targets: [] = None) -> None: 19 | if data is not None: 20 | self.data = data.copy() 21 | if targets is not None: 22 | self.targets = targets.copy() 23 | 24 | def set_channel_names(self,channel_names): 25 | self.channel_names = channel_names -------------------------------------------------------------------------------- /torchsignal/datasets/generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def generate_signal(length_seconds, sampling_rate, frequencies_list, func="sin", add_noise=0, plot=True): 6 | r""" 7 | Generate a `length_seconds` seconds signal at `sampling_rate` sampling rate. 8 | 9 | Args: 10 | length_seconds : int 11 | Duration of signal in seconds (i.e. `10` for a 10-seconds signal) 12 | sampling_rate : int 13 | The sampling rate of the signal. 14 | frequencies_list : 1 or 2 dimension python list a floats 15 | An array of floats, where each float is the desired frequencies to generate (i.e. [5, 12, 15] to generate a signal containing a 5-Hz, 12-Hz and 15-Hz) 16 | 2 dimension python list, i.e. 
[[5, 12, 15],[1]], to generate a signal with 2 signals, where the second channel containing 1-Hz signal 17 | func : string, default: sin 18 | The periodic function to generate signal, either `sin` or `cos` 19 | add_noise : float, default: 0 20 | Add random noise to the signal, where `0` has no noise 21 | plot : boolean 22 | Plot the generated signal 23 | Returns: 24 | signal : 1d ndarray 25 | Generated signal, a numpy array of length `sampling_rate*length_seconds` 26 | """ 27 | 28 | frequencies_list = np.array(frequencies_list, dtype=object) 29 | assert len(frequencies_list.shape) == 1 or len(frequencies_list.shape) == 2, "frequencies_list must be 1d or 2d python list" 30 | 31 | expanded = False 32 | if isinstance(frequencies_list[0], int): 33 | frequencies_list = np.expand_dims(frequencies_list, axis=0) 34 | expanded = True 35 | 36 | npnts = sampling_rate*length_seconds # number of time samples 37 | time = np.arange(0, npnts)/sampling_rate 38 | signal = np.zeros((frequencies_list.shape[0],npnts)) 39 | 40 | for channel in range(0,frequencies_list.shape[0]): 41 | for fi in frequencies_list[channel]: 42 | if func == "cos": 43 | signal[channel] = signal[channel] + np.cos(2*np.pi*fi*time) 44 | else: 45 | signal[channel] = signal[channel] + np.sin(2*np.pi*fi*time) 46 | 47 | # normalize 48 | max = np.repeat(signal[channel].max()[np.newaxis], npnts) 49 | min = np.repeat(signal[channel].min()[np.newaxis], npnts) 50 | signal[channel] = (2*(signal[channel]-min)/(max-min))-1 51 | 52 | if add_noise: 53 | noise = np.random.uniform(low=0, high=add_noise, size=(frequencies_list.shape[0],npnts)) 54 | signal = signal + noise 55 | 56 | if plot: 57 | plt.plot(time, signal.T) 58 | plt.show() 59 | 60 | if expanded: 61 | signal = signal[0] 62 | 63 | return signal 64 | -------------------------------------------------------------------------------- /torchsignal/datasets/hsssvep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | from typing import Tuple 5 | 6 | from torchsignal.datasets.dataset import PyTorchDataset 7 | 8 | 9 | class HSSSVEP(PyTorchDataset): 10 | """ 11 | This is a private dataset. 12 | A Benchmark Dataset for SSVEP-Based Brain–Computer Interfaces 13 | Yijun Wang, Xiaogang Chen, Xiaorong Gao, Shangkai Gao 14 | https://ieeexplore.ieee.org/document/7740878 15 | Sampling rate: 250 Hz 16 | Targets: [8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,8.2,9.2,10.2,11.2,12.2,13.2,14.2,15.2,8.4,9.4,10.4,11.4,12.4,13.4,14.4,15.4,8.6,9.6,10.6,11.6,12.6,13.6,14.6,15.6,8.8,9.8,10.8,11.8,12.8,13.8,14.8,15.8] 17 | 18 | This dataset gathered SSVEP-BCI recordings of 35 healthy subjects (17 females, aged 17-34 years, mean age: 22 years) focusing on 40 characters flickering at different frequencies (8-15.8 Hz with an interval of 0.2 Hz). For each subject, the experiment consisted of 6 blocks. Each block contained 40 trials corresponding to all 40 characters indicated in a random order. Each trial started with a visual cue (a red square) indicating a target stimulus. The cue appeared for 0.5 s on the screen. Subjects were asked to shift their gaze to the target as soon as possible within the cue duration. Following the cue offset, all stimuli started to flicker on the screen concurrently and lasted 5 s. After stimulus offset, the screen was blank for 0.5 s before the next trial began, which allowed the subjects to have short breaks between consecutive trials. Each trial lasted a total of 6 s. 
To facilitate visual fixation, a red triangle appeared below the flickering target during the stimulation period. In each block, subjects were asked to avoid eye blinks during the stimulation period. To avoid visual fatigue, there was a rest for several minutes between two consecutive blocks. 19 | 20 | EEG data were acquired using a Synamps2 system (Neuroscan, Inc.) with a sampling rate of 1000 Hz. The amplifier frequency passband ranged from 0.15 Hz to 200 Hz. Sixty-four channels covered the whole scalp of the subject and were aligned according to the international 10-20 system. The ground was placed on midway between Fz and FPz. The reference was located on the vertex. Electrode impedances were kept below 10 K". To remove the common power-line noise, a notch filter at 50 Hz was applied in data recording. Event triggers generated by the computer to the amplifier and recorded on an event channel synchronized to the EEG data. 21 | 22 | The continuous EEG data was segmented into 6 s epochs (500 ms pre-stimulus, 5.5 s post-stimulus onset). The epochs were subsequently downsampled to 250 Hz. Thus each trial consisted of 1500 time points. Finally, these data were stored as double-precision floating-point values in MATLAB and were named as subject indices (i.e., S01.mat, ", S35.mat). For each file, the data loaded in MATLAB generate a 4-D matrix named "data" with dimensions of [64, 1500, 40, 6]. The four dimensions indicate "Electrode index", "Time points", "Target index", and "Block index". The electrode positions were saved in a "64-channels.loc" file. Six trials were available for each SSVEP frequency. Frequency and phase values for the 40 target indices were saved in a "Freq_Phase.mat" file. 23 | 24 | Information for all subjects was listed in a "Sub_info.txt" file. For each subject, there are five factors including "Subject Index", "Gender", "Age", "Handedness", and "Group". Subjects were divided into an "experienced" group (eight subjects, S01-S08) and a "naive" group (27 subjects, S09-S35) according to their experience in SSVEP-based BCIs. 25 | """ 26 | 27 | def __init__(self, root: str, subject_id: int, verbose: bool = False) -> None: 28 | 29 | self.root = root 30 | self.sample_rate = 1000 31 | self.data, self.targets, self.channel_names = _load_data(self.root, subject_id, verbose) 32 | 33 | def __getitem__(self, n: int) -> Tuple[np.ndarray, int]: 34 | return (self.data[n], self.targets[n]) 35 | 36 | def __len__(self) -> int: 37 | return len(self.data) 38 | 39 | 40 | def _load_data(root, subject_id, verbose): 41 | 42 | path = os.path.join(root, 'S'+str(subject_id)+'.mat') 43 | data_mat = sio.loadmat(path) 44 | 45 | raw_data = data_mat['data'].copy() 46 | raw_data = np.transpose(raw_data, (2,3,0,1)) 47 | 48 | data = [] 49 | targets = [] 50 | for target_id in np.arange(raw_data.shape[0]): 51 | data.extend(raw_data[target_id]) 52 | 53 | this_target = np.array([target_id]*raw_data.shape[1]) 54 | targets.extend(this_target) 55 | 56 | # Each trial started with a 0.5-s target cue. Subjects were asked to shift their gaze to the target as soon as possible. After the cue, all stimuli started to flicker on the screen concurrently for 5 s. Then, the screen was blank for 0.5 s before the next trial began. Each trial lasted 6 s in total. 57 | # We start from 160, because 0.5s Cue + 0.14s (visual latency) as they use phase in stimulus presentation. 
0.64*250 = 160 58 | # We also cut the signal off after 4 seconds 59 | data = np.array(data)[:,:,160:1160] 60 | 61 | targets = np.array(targets) 62 | 63 | channel_names = ['FP1','FPZ','FP2','AF3','AF4','F7','F5','F3','F1','FZ','F2','F4','F6','F8','FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6','FT8','T7','C5','C3','C1','Cz','C2','C4','C6','T8','M1','TP7','CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8','M2','P7','P5','P3','P1','PZ','P2','P4','P6','P8','PO7','PO5','PO3','POz','PO4','PO6','PO8','CB1','O1','Oz','O2','CB2'] 64 | 65 | if verbose: 66 | print('Load path:', path) 67 | print('Data shape', data.shape) 68 | print('Targets shape', targets.shape) 69 | 70 | return data, targets, channel_names -------------------------------------------------------------------------------- /torchsignal/datasets/multiplesubjects.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | 4 | from torchsignal.datasets import OPENBMI 5 | from torchsignal.filter.channels import pick_channels 6 | from torchsignal.filter.butterworth import butter_bandpass_filter 7 | from torchsignal.transform.segment import segment_signal 8 | from torchsignal.datasets.utils import onehot_targets 9 | from torchsignal.datasets.dataset import PyTorchDataset 10 | from torchsignal.datasets.utils import train_test_split, dataset_split_stratified 11 | 12 | 13 | class MultipleSubjects(): 14 | 15 | def __init__(self, 16 | dataset: PyTorchDataset, 17 | root: str, 18 | subject_ids: [], 19 | sessions: [] = None, 20 | selected_channels: [] = None, 21 | segment_config: {} = None, 22 | bandpass_config: {} = None, 23 | one_hot_labels: bool = False, 24 | verbose: bool = False, 25 | ) -> None: 26 | 27 | self.train_dataset_by_subjects = None 28 | self.val_dataset_by_subjects = None 29 | 30 | self.one_hot_labels = one_hot_labels 31 | 32 | self.data_by_subjects = _load_multiple( 33 | root=root, 34 | dataset=dataset, 35 | subject_ids=subject_ids, 36 | sessions=sessions 37 | ) 38 | 39 | _process_data( 40 | data_by_subjects=self.data_by_subjects, 41 | selected_channels=selected_channels, 42 | segment_config=segment_config, 43 | bandpass_config=bandpass_config, 44 | ) 45 | 46 | def split_by_kfold(self, kfold_k=0, kfold_split=5): 47 | self.train_dataset_by_subjects, self.val_dataset_by_subjects = _data_split_stratified(self.data_by_subjects, kfold_k, kfold_split) 48 | if self.one_hot_labels: 49 | _one_hot_labels(self.data_by_subjects, self.train_dataset_by_subjects, self.val_dataset_by_subjects) 50 | 51 | 52 | def leave_one_subject_out(self, selected_subject_id=1, dataloader_batchsize=32, dataloader_shuffle=True): 53 | 54 | assert selected_subject_id in self.data_by_subjects, "Must select subjects in dataset" 55 | 56 | if self.train_dataset_by_subjects is None: 57 | self.split_by_kfold() 58 | 59 | # selected subject 60 | # selected_subject_x = self.data_by_subjects[selected_subject_id].data 61 | # selected_subject_y = self.data_by_subjects[selected_subject_id].targets 62 | 63 | selected_subject_x = np.concatenate((self.train_dataset_by_subjects[selected_subject_id].data, self.val_dataset_by_subjects[selected_subject_id].data), axis=0) 64 | selected_subject_y = np.concatenate((self.train_dataset_by_subjects[selected_subject_id].targets, self.val_dataset_by_subjects[selected_subject_id].targets), axis=0) 65 | test_dataset = PyTorchDataset(selected_subject_x, selected_subject_y) 66 | 67 | # the rest 68 | other_subjects_x_train = [] 69 | other_subjects_y_train = 
[] 70 | other_subjects_x_val = [] 71 | other_subjects_y_val = [] 72 | 73 | for subject_id in list(self.data_by_subjects.keys()): 74 | if subject_id != selected_subject_id: 75 | other_subjects_x_train.extend(self.train_dataset_by_subjects[subject_id].data) 76 | other_subjects_y_train.extend(self.train_dataset_by_subjects[subject_id].targets) 77 | 78 | other_subjects_x_val.extend(self.val_dataset_by_subjects[subject_id].data) 79 | other_subjects_y_val.extend(self.val_dataset_by_subjects[subject_id].targets) 80 | 81 | other_subjects_x_train = np.array(other_subjects_x_train) 82 | other_subjects_y_train = np.array(other_subjects_y_train) 83 | other_subjects_x_val = np.array(other_subjects_x_val) 84 | other_subjects_y_val = np.array(other_subjects_y_val) 85 | 86 | train_dataset = PyTorchDataset(other_subjects_x_train, other_subjects_y_train) 87 | val_dataset = PyTorchDataset(other_subjects_x_val, other_subjects_y_val) 88 | 89 | # data loader 90 | train_loader = DataLoader(train_dataset, batch_size=dataloader_batchsize, shuffle=dataloader_shuffle) 91 | val_loader = DataLoader(val_dataset, batch_size=dataloader_batchsize, shuffle=dataloader_shuffle) 92 | selected_subject_loader = DataLoader(test_dataset, batch_size=dataloader_batchsize, shuffle=False) 93 | 94 | return train_loader, val_loader, selected_subject_loader 95 | 96 | 97 | def _load_multiple(root, dataset: PyTorchDataset, subject_ids: [], sessions: [], verbose: bool = False) -> None: 98 | data_by_subjects = {} 99 | 100 | for subject_id in subject_ids: 101 | print('Load subject:', subject_id) 102 | subject_data = None 103 | subject_target = None 104 | 105 | if sessions: 106 | for session in sessions: 107 | subject_dataset = dataset(root=root, subject_id=subject_id, session=session) 108 | 109 | if subject_data is None: # if its session #1, will be None 110 | subject_data = np.zeros((0, subject_dataset.data.shape[1], subject_dataset.data.shape[2])) 111 | subject_target = np.zeros((0, )) 112 | 113 | subject_data = np.concatenate((subject_data, subject_dataset.data)) 114 | subject_target = np.concatenate((subject_target, subject_dataset.targets)) 115 | else: 116 | subject_dataset = dataset(root=root, subject_id=subject_id) 117 | 118 | if subject_data is None: # if its session #1, will be None 119 | subject_data = np.zeros((0, subject_dataset.data.shape[1], subject_dataset.data.shape[2])) 120 | subject_target = np.zeros((0, )) 121 | 122 | subject_data = np.concatenate((subject_data, subject_dataset.data)) 123 | subject_target = np.concatenate((subject_target, subject_dataset.targets)) 124 | 125 | subject_target = subject_target.astype(np.long) 126 | subject_dataset_new = PyTorchDataset(data=subject_data, targets=subject_target) 127 | subject_dataset_new.set_channel_names(subject_dataset.channel_names) 128 | data_by_subjects[subject_id] = subject_dataset_new 129 | 130 | return data_by_subjects 131 | 132 | 133 | def _process_data(data_by_subjects, selected_channels, segment_config, bandpass_config): 134 | 135 | for subject_id in list(data_by_subjects.keys()): 136 | subject_dataset = data_by_subjects[subject_id] 137 | 138 | subject_data = subject_dataset.data 139 | 140 | # filter channels 141 | if selected_channels is not None: 142 | subject_data = pick_channels( 143 | data=subject_data, 144 | channel_names=subject_dataset.channel_names, 145 | selected_channels=selected_channels 146 | ) 147 | 148 | # segment signal 149 | if segment_config is not None: 150 | subject_data = segment_signal( 151 | signal=subject_data, 152 | 
window_len=segment_config['window_len'], 153 | shift_len=segment_config['shift_len'], 154 | sample_rate=segment_config['sample_rate'], 155 | add_segment_axis=segment_config['add_segment_axis'], 156 | ) 157 | 158 | subject_data_full = np.zeros((subject_data.shape[0], subject_data.shape[1], subject_data.shape[3])) 159 | 160 | for trial in range(0, subject_data_full.shape[0]): 161 | for channel in range(0, subject_data_full.shape[1]): 162 | subject_data_full[trial, channel, :] = subject_data[trial, channel, 0, :] 163 | 164 | subject_data = subject_data_full 165 | 166 | # filter by bandpass 167 | if bandpass_config is not None: 168 | subject_data = butter_bandpass_filter(subject_data, lowcut=bandpass_config["lowcut"], highcut=bandpass_config["highcut"], sample_rate=bandpass_config["sample_rate"], order=bandpass_config["order"]) 169 | 170 | subject_dataset.set_data_targets(data=subject_data) 171 | 172 | 173 | def _train_test_dataset(data_by_subjects): 174 | train_dataset_by_subjects = {} 175 | test_dataset_by_subjects = {} 176 | 177 | for subject_id in list(data_by_subjects.keys()): 178 | train_dataset, test_dataset = train_test_split(data_by_subjects[subject_id].data, data_by_subjects[subject_id].targets) 179 | 180 | train_dataset_by_subjects[subject_id] = train_dataset 181 | test_dataset_by_subjects[subject_id] = test_dataset 182 | 183 | return train_dataset_by_subjects, test_dataset_by_subjects 184 | 185 | 186 | def _data_split_stratified(data_by_subjects, kfold_k, kfold_split): 187 | train_dataset_by_subjects = {} 188 | test_dataset_by_subjects = {} 189 | 190 | for subject_id in list(data_by_subjects.keys()): 191 | data = dataset_split_stratified(data_by_subjects[subject_id].data, data_by_subjects[subject_id].targets, k=kfold_k, n_splits=kfold_split, pytorch_dataset_object=PyTorchDataset) 192 | for i in range(len(data)): 193 | if i == 0: 194 | train_dataset_by_subjects[subject_id] = data[i] 195 | elif i == 1: 196 | test_dataset_by_subjects[subject_id] = data[i] 197 | 198 | return train_dataset_by_subjects, test_dataset_by_subjects 199 | 200 | 201 | def _one_hot_labels(data_by_subjects, train_dataset_by_subjects, val_dataset_by_subjects): 202 | 203 | num_class = len(list(set(data_by_subjects[1].targets))) 204 | 205 | for subject_id in list(train_dataset_by_subjects.keys()): 206 | dataset = train_dataset_by_subjects[subject_id] 207 | 208 | dataset_targets = onehot_targets(dataset.targets, num_class=num_class) 209 | dataset.set_data_targets(targets=dataset_targets) 210 | 211 | for subject_id in list(val_dataset_by_subjects.keys()): 212 | dataset = val_dataset_by_subjects[subject_id] 213 | 214 | dataset_targets = onehot_targets(dataset.targets, num_class=num_class) 215 | dataset.set_data_targets(targets=dataset_targets) -------------------------------------------------------------------------------- /torchsignal/datasets/openbmi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | from typing import Tuple 5 | 6 | from torchsignal.datasets.dataset import PyTorchDataset 7 | 8 | 9 | class OPENBMI(PyTorchDataset): 10 | """ 11 | This is a private dataset. 12 | EEG dataset and OpenBMI toolbox for three BCI paradigms: an investigation into BCI illiteracy. 13 | Min-Ho Lee, O-Yeon Kwon, Yong-Jeong Kim, Hong-Kyung Kim, Young-Eun Lee, John Williamson, Siamac Fazli, Seong-Whan Lee. 
14 | https://academic.oup.com/gigascience/article/8/5/giz002/5304369 15 | Target frequencies: 5.45, 6.67, 8.57, 12 Hz 16 | Sampling rate: 1000 Hz 17 | """ 18 | 19 | def __init__(self, root: str, subject_id: int, session: int, verbose: bool = False) -> None: 20 | 21 | self.root = root 22 | self.sample_rate = 1000 23 | self.data, self.targets, self.channel_names = _load_data( 24 | self.root, subject_id, session, verbose) 25 | 26 | def __getitem__(self, n: int) -> Tuple[np.ndarray, int]: 27 | return (self.data[n], self.targets[n]) 28 | 29 | def __len__(self) -> int: 30 | return len(self.data) 31 | 32 | 33 | def _load_data(root, subject_id, session, verbose): 34 | 35 | path = os.path.join(root, 'session'+str(session), 36 | 's'+str(subject_id)+'/EEG_SSVEP.mat') 37 | 38 | data_mat = sio.loadmat(path) 39 | 40 | objects_in_mat = [] 41 | for i in data_mat['EEG_SSVEP_train'][0][0]: 42 | objects_in_mat.append(i) 43 | 44 | # data 45 | data = objects_in_mat[0][:, :, :].copy() 46 | data = np.transpose(data, (1, 2, 0)) 47 | data = data.astype(np.float32) 48 | 49 | # label 50 | targets = [] 51 | for i in range(data.shape[0]): 52 | targets.append([objects_in_mat[2][0][i], 0, objects_in_mat[4][0][i]]) 53 | targets = np.array(targets) 54 | targets = targets[:, 2] 55 | targets = targets-1 56 | 57 | # channel 58 | channel_names = [v[0] for v in objects_in_mat[8][0]] 59 | 60 | if verbose: 61 | print('Load path:', path) 62 | print('Objects in .mat', len(objects_in_mat), 63 | data_mat['EEG_SSVEP_train'].dtype.descr) 64 | print() 65 | print('Data shape', data.shape) 66 | print('Targets shape', targets.shape) 67 | 68 | return data, targets, channel_names -------------------------------------------------------------------------------- /torchsignal/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.model_selection import StratifiedKFold 3 | from torchsignal.datasets.dataset import PyTorchDataset 4 | 5 | 6 | def train_test_split(X, y): 7 | train, test = dataset_split_stratified( 8 | X, y, k=0, n_splits=4, seed=71, shuffle=True, pytorch_dataset_object=PyTorchDataset) 9 | return train, test 10 | 11 | 12 | def dataset_split_stratified(X, y, k=-1, n_splits=3, seed=71, shuffle=True, pytorch_dataset_object=None): 13 | return_data = [] 14 | skf = StratifiedKFold( 15 | n_splits=n_splits, random_state=seed, shuffle=shuffle) 16 | split_data = skf.split(X, y) 17 | 18 | for train_index, test_index in split_data: 19 | X_train, X_test = X[train_index], X[test_index] 20 | y_train, y_test = y[train_index], y[test_index] 21 | 22 | if pytorch_dataset_object is not None: 23 | return_data.append(pytorch_dataset_object(X_train, y_train)) 24 | return_data.append(pytorch_dataset_object(X_test, y_test)) 25 | else: 26 | return_data.append((X_train, y_train)) 27 | return_data.append((X_test, y_test)) 28 | 29 | if k == -1: 30 | return tuple(return_data) 31 | else: 32 | return tuple(return_data)[k*2:k*2+2] 33 | 34 | 35 | def onehot_targets(targets, num_class=4): 36 | onehot_y = np.zeros((targets.shape[0], num_class)) 37 | onehot_y[np.arange(onehot_y.shape[0]), targets] = 1 38 | onehot_y = onehot_y.astype(np.long) 39 | return onehot_y 40 | -------------------------------------------------------------------------------- /torchsignal/filter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinglescode/torchsignal/6172bc2b18eeafa9464cfba678e9c02ea4ed5e2a/torchsignal/filter/__init__.py 
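Bridging the dataset wrappers above and the filtering utilities that follow, here is a usage sketch of `MultipleSubjects` loading several OpenBMI subjects, bandpass-filtering them, and producing PyTorch `DataLoader`s for a leave-one-subject-out evaluation. The root path and subject list are placeholders, and the example assumes the (privately distributed) OpenBMI `.mat` files are available locally:

```python
from torchsignal.datasets import OPENBMI
from torchsignal.datasets.multiplesubjects import MultipleSubjects

data = MultipleSubjects(
    dataset=OPENBMI,
    root="path/to/openbmi",          # placeholder: folder containing session1/, session2/
    subject_ids=[1, 2, 3],
    sessions=[1],
    bandpass_config={"lowcut": 4, "highcut": 40, "sample_rate": 1000, "order": 4},
)

# stratified k-fold split per subject, then hold subject 1 out for testing
data.split_by_kfold(kfold_k=0, kfold_split=5)
train_loader, val_loader, test_loader = data.leave_one_subject_out(
    selected_subject_id=1, dataloader_batchsize=32)

for x, y in train_loader:
    print(x.shape, y.shape)          # (batch, channel, time), (batch,)
    break
```

`selected_channels` and `segment_config` can be passed in the same way when channel selection or windowing is needed; their expected keys are visible in `_process_data` above.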
-------------------------------------------------------------------------------- /torchsignal/filter/butterworth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from scipy.signal import butter, filtfilt, sosfiltfilt, freqz 4 | from torchsignal.transform.fft import fast_fourier_transform 5 | 6 | 7 | def butter_bandpass(lowcut, highcut, sample_rate, order=4, output='ba'): 8 | r""" 9 | Create a Butterworth bandpass filter 10 | Design an Nth-order digital or analog Butterworth filter and return the filter coefficients. 11 | 12 | Args: 13 | lowcut : int 14 | Lower bound filter 15 | highcut : int 16 | Upper bound filter 17 | sample_rate : int 18 | Sampling frequency 19 | order : int, default: 4 20 | Order of the filter 21 | output : string, default: ba 22 | Type of output {‘ba’, ‘zpk’, ‘sos’} 23 | Returns: 24 | butter : ndarray 25 | Butterworth filter 26 | Dependencies: 27 | butter : scipy.signal.butter 28 | """ 29 | nyq = sample_rate * 0.5 30 | low = lowcut / nyq 31 | high = highcut / nyq 32 | return butter(order, [low, high], btype='bandpass', output=output) 33 | 34 | 35 | def butter_bandpass_filter_signal_1d(signal, lowcut, highcut, sample_rate, order, verbose=False): 36 | r""" 37 | Digital filter bandpass zero-phase implementation (filtfilt) 38 | Apply a digital filter forward and backward to a signal 39 | 40 | Args: 41 | signal : ndarray, shape (time,) 42 | Single input signal in time domain 43 | lowcut : int 44 | Lower bound filter 45 | highcut : int 46 | Upper bound filter 47 | sample_rate : int 48 | Sampling frequency 49 | order : int, default: 4 50 | Order of the filter 51 | verbose : boolean, default: False 52 | Print and plot details 53 | Returns: 54 | y : ndarray 55 | Filter signal 56 | Dependencies: 57 | filtfilt : scipy.signal.filtfilt 58 | butter_bandpass : function 59 | plt : `matplotlib.pyplot` package 60 | freqz : scipy.signal.freqz 61 | fast_fourier_transform : function 62 | """ 63 | b, a = butter_bandpass(lowcut, highcut, sample_rate, order) 64 | y = filtfilt(b, a, signal) 65 | 66 | if verbose: 67 | w, h = freqz(b, a) 68 | plt.plot((sample_rate * 0.5 / np.pi) * w, 69 | abs(h), label="order = %d" % order) 70 | plt.plot([0, 0.5 * sample_rate], [np.sqrt(0.5), np.sqrt(0.5)], 71 | '--', label='sqrt(0.5)') 72 | plt.xlabel('Frequency (Hz)') 73 | plt.ylabel('Gain') 74 | plt.grid(True) 75 | plt.legend(loc='best') 76 | low = max(0, lowcut-(sample_rate/100)) 77 | high = highcut+(sample_rate/100) 78 | plt.xlim([low, high]) 79 | plt.ylim([0, 1.2]) 80 | plt.title('Frequency response of filter - lowcut:' + 81 | str(lowcut)+', highcut:'+str(highcut)) 82 | plt.show() 83 | 84 | # TIME 85 | plt.plot(signal, label='Signal') 86 | plt.title('Signal') 87 | plt.show() 88 | 89 | plt.plot(y, label='Filtered') 90 | plt.title('Bandpass filtered') 91 | plt.show() 92 | 93 | # FREQ 94 | lower_xlim = lowcut-10 if (lowcut-10) > 0 else 0 95 | fast_fourier_transform( 96 | signal, sample_rate, plot=True, plot_xlim=[lower_xlim, highcut+20], plot_label='Signal') 97 | fast_fourier_transform( 98 | y, sample_rate, plot=True, plot_xlim=[lower_xlim, highcut+20], plot_label='Filtered') 99 | 100 | plt.xlim([lower_xlim, highcut+20]) 101 | plt.ylim([0, 2]) 102 | plt.legend() 103 | plt.xlabel('Frequency (Hz)') 104 | plt.show() 105 | 106 | print('Input: Signal shape', signal.shape) 107 | print('Output: Signal shape', y.shape) 108 | return y 109 | 110 | 111 | def butter_bandpass_filter(signal, lowcut, highcut, sample_rate, 
order, verbose=False): 112 | r""" 113 | Digital filter bandpass zero-phase implementation (filtfilt) 114 | Apply a digital filter forward and backward to a signal 115 | 116 | Dependencies: 117 | sosfiltfilt : scipy.signal.sosfiltfilt 118 | butter_bandpass : function 119 | fast_fourier_transform : function 120 | plt : `matplotlib.pyplot` package 121 | Args: 122 | signal : ndarray, shape (trial,channel,time) 123 | Input signal by trials in time domain 124 | lowcut : int 125 | Lower bound filter 126 | highcut : int 127 | Upper bound filter 128 | sample_rate : int 129 | Sampling frequency 130 | order : int, default: 4 131 | Order of the filter 132 | verbose : boolean, default: False 133 | Print and plot details 134 | Returns: 135 | y : ndarray 136 | Filter signal 137 | """ 138 | sos = butter_bandpass(lowcut, highcut, sample_rate, 139 | order=order, output='sos') 140 | y = sosfiltfilt(sos, signal, axis=2) 141 | 142 | if verbose: 143 | tmp_x = signal[0, 0] 144 | tmp_y = y[0, 0] 145 | 146 | # time domain 147 | plt.plot(tmp_x, label='signal') 148 | plt.show() 149 | 150 | plt.plot(tmp_y, label='Filtered') 151 | plt.show() 152 | 153 | # freq domain 154 | lower_xlim = lowcut-10 if (lowcut-10) > 0 else 0 155 | fast_fourier_transform( 156 | tmp_x, sample_rate, plot=True, plot_xlim=[lower_xlim, highcut+20], plot_label='Signal') 157 | fast_fourier_transform( 158 | tmp_y, sample_rate, plot=True, plot_xlim=[lower_xlim, highcut+20], plot_label='Filtered') 159 | 160 | plt.xlim([lower_xlim, highcut+20]) 161 | plt.ylim([0, 2]) 162 | plt.legend() 163 | plt.xlabel('Frequency (Hz)') 164 | plt.show() 165 | 166 | print('Input: Signal shape', signal.shape) 167 | print('Output: Signal shape', y.shape) 168 | 169 | return y 170 | -------------------------------------------------------------------------------- /torchsignal/filter/channels.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def pick_channels(data: np.ndarray, 5 | channel_names: [str], 6 | selected_channels: [str], 7 | verbose: bool = False) -> np.ndarray: 8 | 9 | picked_ch = pick_channels_mne(channel_names, selected_channels) 10 | data = data[:, picked_ch, :] 11 | 12 | if verbose: 13 | print('picking channels: channel_names', 14 | len(channel_names), channel_names) 15 | print('picked_ch', picked_ch) 16 | print() 17 | 18 | del picked_ch 19 | 20 | return data 21 | 22 | 23 | def pick_channels_mne(ch_names, include, exclude=[], ordered=False): 24 | """Pick channels by names. 25 | Returns the indices of ``ch_names`` in ``include`` but not in ``exclude``. 26 | Taken from https://github.com/mne-tools/mne-python/blob/master/mne/io/pick.py 27 | 28 | Parameters 29 | ---------- 30 | ch_names : list of str 31 | List of channels. 32 | include : list of str 33 | List of channels to include (if empty include all available). 34 | .. note:: This is to be treated as a set. The order of this list 35 | is not used or maintained in ``sel``. 36 | exclude : list of str 37 | List of channels to exclude (if empty do not exclude any channel). 38 | Defaults to []. 39 | ordered : bool 40 | If true (default False), treat ``include`` as an ordered list 41 | rather than a set, and any channels from ``include`` are missing 42 | in ``ch_names`` an error will be raised. 43 | .. versionadded:: 0.18 44 | Returns 45 | ------- 46 | sel : array of int 47 | Indices of good channels. 
48 | See Also 49 | -------- 50 | pick_channels_regexp, pick_types 51 | """ 52 | if len(np.unique(ch_names)) != len(ch_names): 53 | raise RuntimeError('ch_names is not a unique list, picking is unsafe') 54 | # _check_excludes_includes(include) 55 | # _check_excludes_includes(exclude) 56 | if not ordered: 57 | if not isinstance(include, set): 58 | include = set(include) 59 | if not isinstance(exclude, set): 60 | exclude = set(exclude) 61 | sel = [] 62 | for k, name in enumerate(ch_names): 63 | if (len(include) == 0 or name in include) and name not in exclude: 64 | sel.append(k) 65 | else: 66 | if not isinstance(include, list): 67 | include = list(include) 68 | if len(include) == 0: 69 | include = list(ch_names) 70 | if not isinstance(exclude, list): 71 | exclude = list(exclude) 72 | sel, missing = list(), list() 73 | for name in include: 74 | if name in ch_names: 75 | if name not in exclude: 76 | sel.append(ch_names.index(name)) 77 | else: 78 | missing.append(name) 79 | if len(missing): 80 | raise ValueError('Missing channels from ch_names required by ' 81 | 'include:\n%s' % (missing,)) 82 | return np.array(sel, int) 83 | -------------------------------------------------------------------------------- /torchsignal/model/CompactEEGNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """EEGNet: Compact Convolutional Neural Network (Compact-CNN) https://arxiv.org/pdf/1803.04566.pdf 3 | """ 4 | import torch 5 | from torch import nn 6 | from .common.conv import SeparableConv2d 7 | 8 | 9 | class CompactEEGNet(nn.Module): 10 | """ 11 | EEGNet: Compact Convolutional Neural Network (Compact-CNN) 12 | https://arxiv.org/pdf/1803.04566.pdf 13 | """ 14 | def __init__(self, num_channel=10, num_classes=4, signal_length=1000, f1=96, f2=96, d=1): 15 | super().__init__() 16 | 17 | self.signal_length = signal_length 18 | 19 | # layer 1 20 | self.conv1 = nn.Conv2d(1, f1, (1, signal_length), padding=(0,signal_length//2)) 21 | self.bn1 = nn.BatchNorm2d(f1) 22 | self.depthwise_conv = nn.Conv2d(f1, d*f1, (num_channel, 1), groups=f1) 23 | self.bn2 = nn.BatchNorm2d(d*f1) 24 | self.avgpool1 = nn.AvgPool2d((1,4)) 25 | 26 | # layer 2 27 | self.separable_conv = SeparableConv2d( 28 | in_channels=f1, 29 | out_channels=f2, 30 | kernel_size=(1,16) 31 | ) 32 | self.bn3 = nn.BatchNorm2d(f2) 33 | self.avgpool2 = nn.AvgPool2d((1,8)) 34 | 35 | # layer 3 36 | self.linear = nn.Linear(in_features=f2*(signal_length//32), out_features=num_classes) 37 | 38 | self.dropout = nn.Dropout(p=0.5) 39 | self.elu = nn.ELU() 40 | 41 | def forward(self, x): 42 | 43 | # layer 1 44 | x = torch.unsqueeze(x,1) 45 | x = self.conv1(x) 46 | x = self.bn1(x) 47 | x = self.depthwise_conv(x) 48 | x = self.bn2(x) 49 | x = self.elu(x) 50 | x = self.avgpool1(x) 51 | x = self.dropout(x) 52 | 53 | # layer 2 54 | x = self.separable_conv(x) 55 | x = self.bn3(x) 56 | x = self.elu(x) 57 | x = self.avgpool2(x) 58 | x = self.dropout(x) 59 | 60 | # layer 3 61 | x = torch.flatten(x, start_dim=1) 62 | x = self.linear(x) 63 | 64 | return x 65 | -------------------------------------------------------------------------------- /torchsignal/model/MIEEGNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .common.conv import DepthwiseConv2d, SeparableConv2d, Conv2d 4 | 5 | 6 | class MIEEGNet(nn.Module): 7 | """ 8 | Mobile Inception EEGNet model 9 | MI-EEGNET: A novel Convolutional Neural Network for motor imagery 
classification 10 | https://www.sciencedirect.com/science/article/abs/pii/S016502702030460X 11 | 12 | Usage: 13 | from torchsignal.model import MIEEGNet 14 | model = MIEEGNet(num_channel=22, num_classes=12, signal_length=256) 15 | x = torch.randn(1, 22, 256) 16 | print("Input shape:", x.shape) # torch.Size([1, 22, 256]) 17 | y = model(x) 18 | print("Output shape:", y.shape) # torch.Size([1, 12]) 19 | 20 | Note: 21 | 1. In my implementation, I did not get the same number of parameters as stated in the paper. This model has 106304 params instead of 162564 as stated in the paper. 22 | 2. Somehow this kind of architecture only support `signal_length` with length of ^2 (128, 256, 512, etc), because of the 4 towers, the size must align. 23 | 3. Certain `num_classes` might have issues, because they specifically state to use global pooling at the end. Replacing it with a linear or a Conv2d can better manage the kernel sizes. 24 | """ 25 | def __init__(self, num_channel=22, num_classes=4, signal_length=256, depth=4, first_filter_size=64): 26 | super().__init__() 27 | 28 | self.num_classes = num_classes 29 | 30 | filter_size = [first_filter_size, first_filter_size*depth, first_filter_size*depth*4] 31 | 32 | self.conv1 = nn.Sequential( 33 | Conv2d(1, filter_size[0], kernel_size=(1, 16), padding="SAME"), 34 | nn.BatchNorm2d(filter_size[0]), 35 | DepthwiseConv2d(filter_size[0], filter_size[0]//depth, kernel_size=(num_channel, 1), depth=4, bias=True), 36 | nn.BatchNorm2d(filter_size[0]), 37 | nn.ELU(inplace=True), 38 | nn.AvgPool2d(kernel_size=(1,2)), 39 | nn.Dropout(p=0.5) 40 | ) 41 | 42 | self.convtower1 = nn.Sequential( 43 | Conv2d(filter_size[0], filter_size[0], kernel_size=(1, 1), padding="SAME"), 44 | SeparableConv2d(filter_size[0], filter_size[0], kernel_size=(1,7), bias=True), 45 | nn.BatchNorm2d(filter_size[0]), 46 | nn.ELU(inplace=True), 47 | nn.Dropout(p=0.5), 48 | SeparableConv2d(filter_size[0], filter_size[0], kernel_size=(1,7), bias=True), 49 | nn.AvgPool2d(kernel_size=(1,2)), 50 | ) 51 | 52 | self.convtower2 = nn.Sequential( 53 | Conv2d(filter_size[0], filter_size[0], kernel_size=(1, 1), padding="SAME"), 54 | SeparableConv2d(filter_size[0], filter_size[0], kernel_size=(1,9), bias=True), 55 | nn.BatchNorm2d(filter_size[0]), 56 | nn.ELU(inplace=True), 57 | nn.Dropout(p=0.5), 58 | SeparableConv2d(filter_size[0], filter_size[0], kernel_size=(1,9), bias=True), 59 | nn.AvgPool2d(kernel_size=(1,2)), 60 | ) 61 | 62 | self.convtower3 = nn.Sequential( 63 | nn.AvgPool2d(kernel_size=(1,2)), 64 | Conv2d(filter_size[0], filter_size[0], kernel_size=(1, 1), padding="SAME"), 65 | ) 66 | 67 | self.convtower4 = nn.Sequential( 68 | Conv2d(filter_size[0], filter_size[0], kernel_size=(1, 1), stride=(1,2), padding="SAME"), 69 | ) 70 | 71 | self.conv2 = nn.Sequential( 72 | nn.BatchNorm2d(filter_size[1]), 73 | nn.ELU(inplace=True), 74 | SeparableConv2d(filter_size[1], filter_size[1], kernel_size=(1,5), bias=True), 75 | nn.BatchNorm2d(filter_size[1]), 76 | nn.ELU(inplace=True), 77 | nn.Dropout(p=0.5), 78 | nn.AvgPool3d(kernel_size=(filter_size[1]//self.num_classes, 1, signal_length//4)) 79 | ) 80 | 81 | def forward(self, x): 82 | x = torch.unsqueeze(x,1) 83 | x = self.conv1(x) 84 | 85 | x1 = self.convtower1(x) 86 | x2 = self.convtower2(x) 87 | x3 = self.convtower3(x) 88 | x4 = self.convtower4(x) 89 | x = torch.cat([x1, x2, x3, x4], dim=1) 90 | 91 | x = self.conv2(x) 92 | x = x.view(x.size()[0],-1) 93 | 94 | return x 95 | -------------------------------------------------------------------------------- 
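A quick shape trace for the MI-EEGNet towers above (note 2 in the docstring): every tower must emit the same time length before the channel-wise concatenation, which is why power-of-two signal lengths are recommended. This is a minimal sketch, assuming the SeparableConv2d blocks defined in model/common/conv.py preserve the time axis; the model arguments and expected output are taken from the docstring example.

```
import torch
from torchsignal.model import MIEEGNet

# signal_length = 256 (a power of two, per note 2 in the docstring):
#   conv1 ends with AvgPool2d((1, 2))               -> time 256 -> 128
#   towers 1-3 each apply AvgPool2d((1, 2))         -> time 128 -> 64
#   tower 4 applies a 1x1 conv with stride (1, 2)   -> time 128 -> 64
# All four towers emit the same time length, so torch.cat along dim=1 works.
model = MIEEGNet(num_channel=22, num_classes=12, signal_length=256)
x = torch.randn(1, 22, 256)
y = model(x)
print(y.shape)  # torch.Size([1, 12]), as in the docstring example
```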
/torchsignal/model/MultitaskSSVEP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .common.conv import Conv2dBlockELU 4 | 5 | 6 | class MultitaskSSVEP(nn.Module): 7 | """ 8 | Using multi-task learning to capture signals simultaneously from the fovea efficiently and the neighboring targets in the peripheral vision generate a visual response map. A calibration-free user-independent solution, desirable for clinical diagnostics. A stepping stone for an objective assessment of glaucoma patients’ visual field. 9 | Learn more about this model at https://jinglescode.github.io/ssvep-multi-task-learning/ 10 | This model is a multi-label model. Although it produces multiple outputs, we also used this model to get our multi-class results in our paper. 11 | 12 | Usage: 13 | model = MultitaskSSVEP( 14 | num_channel=11, 15 | num_classes=40, 16 | signal_length=250, 17 | ) 18 | 19 | x = torch.randn(2, 11, 250) 20 | print("Input shape:", x.shape) # torch.Size([2, 11, 250]) 21 | y = model(x) 22 | print("Output shape:", y.shape) # torch.Size([2, 40, 2]) 23 | 24 | Cite: 25 | @inproceedings{khok2020deep, 26 | title={Deep Multi-Task Learning for SSVEP Detection and Visual Response Mapping}, 27 | author={Khok, Hong Jing and Koh, Victor Teck Chang and Guan, Cuntai}, 28 | booktitle={2020 IEEE International Conference on Systems, Man, and Cybernetics (SMC)}, 29 | pages={1280--1285}, 30 | year={2020}, 31 | organization={IEEE} 32 | } 33 | """ 34 | 35 | def __init__(self, num_channel=10, num_classes=4, signal_length=1000, filters_n1=4, kernel_window_ssvep=59, kernel_window=19, conv_3_dilation=4, conv_4_dilation=4): 36 | super().__init__() 37 | 38 | filters = [filters_n1, filters_n1 * 2] 39 | 40 | self.conv_1 = Conv2dBlockELU(in_channels=1, out_channels=filters[0], kernel_size=(1, kernel_window_ssvep), w_in=signal_length) 41 | self.conv_2 = Conv2dBlockELU(in_channels=filters[0], out_channels=filters[0], kernel_size=(num_channel, 1)) 42 | self.conv_3 = Conv2dBlockELU(in_channels=filters[0], out_channels=filters[1], kernel_size=(1, kernel_window), padding=(0,conv_3_dilation-1), dilation=(1,conv_3_dilation), w_in=self.conv_1.w_out) 43 | self.conv_4 = Conv2dBlockELU(in_channels=filters[1], out_channels=filters[1], kernel_size=(1, kernel_window), padding=(0,conv_4_dilation-1), dilation=(1,conv_4_dilation), w_in=self.conv_3.w_out) 44 | self.conv_mtl = multitask_block(filters[1]*num_classes, num_classes, kernel_size=(1, self.conv_4.w_out)) 45 | 46 | self.dropout = nn.Dropout(p=0.5) 47 | 48 | def forward(self, x): 49 | x = torch.unsqueeze(x,1) 50 | 51 | x = self.conv_1(x) 52 | x = self.conv_2(x) 53 | x = self.dropout(x) 54 | 55 | x = self.conv_3(x) 56 | x = self.conv_4(x) 57 | x = self.dropout(x) 58 | 59 | x = self.conv_mtl(x) 60 | return x 61 | 62 | 63 | class multitask_block(nn.Module): 64 | def __init__(self, in_ch, num_classes, kernel_size): 65 | super(multitask_block, self).__init__() 66 | self.num_classes = num_classes 67 | self.conv_mtl = nn.Conv2d(in_ch, num_classes*2, kernel_size=kernel_size, groups=num_classes) 68 | 69 | def forward(self, x): 70 | x = torch.cat(self.num_classes*[x], 1) 71 | x = self.conv_mtl(x) 72 | x = x.view(-1, self.num_classes, 2) 73 | return x 74 | -------------------------------------------------------------------------------- /torchsignal/model/MultitaskSSVEPClassifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from 
.common.conv import Conv2dBlockELU 4 | 5 | 6 | class MultitaskSSVEPClassifier(nn.Module): 7 | """ 8 | Using multi-task learning to capture signals simultaneously from the fovea efficiently and the neighboring targets in the peripheral vision generate a visual response map. A calibration-free user-independent solution, desirable for clinical diagnostics. A stepping stone for an objective assessment of glaucoma patients’ visual field. 9 | Learn more about this model at https://jinglescode.github.io/ssvep-multi-task-learning/ 10 | This model is a multi-label model. Although it produces multiple outputs, we also used this model to get our multi-class results in our paper. 11 | 12 | Usage: 13 | model = MultitaskSSVEPClassifier( 14 | num_channel=11, 15 | num_classes=40, 16 | signal_length=250, 17 | ) 18 | 19 | x = torch.randn(2, 11, 250) 20 | print("Input shape:", x.shape) # torch.Size([2, 11, 250]) 21 | y = model(x) 22 | print("Output shape:", y.shape) # torch.Size([2, 40]) 23 | 24 | Cite: 25 | @inproceedings{khok2020deep, 26 | title={Deep Multi-Task Learning for SSVEP Detection and Visual Response Mapping}, 27 | author={Khok, Hong Jing and Koh, Victor Teck Chang and Guan, Cuntai}, 28 | booktitle={2020 IEEE International Conference on Systems, Man, and Cybernetics (SMC)}, 29 | pages={1280--1285}, 30 | year={2020}, 31 | organization={IEEE} 32 | } 33 | """ 34 | 35 | def __init__(self, num_channel=10, num_classes=4, signal_length=1000, filters_n1=4, kernel_window_ssvep=59, kernel_window=19, conv_3_dilation=4, conv_4_dilation=4): 36 | super().__init__() 37 | 38 | filters = [filters_n1, filters_n1 * 2] 39 | self.num_classes = num_classes 40 | 41 | self.conv_1 = Conv2dBlockELU(in_channels=1, out_channels=filters[0], kernel_size=(1, kernel_window_ssvep), w_in=signal_length) 42 | self.conv_2 = Conv2dBlockELU(in_channels=filters[0], out_channels=filters[0], kernel_size=(num_channel, 1)) 43 | self.conv_3 = Conv2dBlockELU(in_channels=filters[0], out_channels=filters[1], kernel_size=(1, kernel_window), padding=(0,conv_3_dilation-1), dilation=(1,conv_3_dilation), w_in=self.conv_1.w_out) 44 | self.conv_4 = Conv2dBlockELU(in_channels=filters[1], out_channels=filters[1], kernel_size=(1, kernel_window), padding=(0,conv_4_dilation-1), dilation=(1,conv_4_dilation), w_in=self.conv_3.w_out) 45 | self.conv_mtl = multitask_block(filters[1]*num_classes, num_classes, kernel_size=(1, self.conv_4.w_out)) 46 | self.classify = nn.Conv1d(num_classes, num_classes, kernel_size=2) 47 | self.dropout = nn.Dropout(p=0.5) 48 | 49 | def forward(self, x): 50 | x = torch.unsqueeze(x,1) 51 | 52 | x = self.conv_1(x) 53 | x = self.conv_2(x) 54 | x = self.dropout(x) 55 | 56 | x = self.conv_3(x) 57 | x = self.conv_4(x) 58 | x = self.dropout(x) 59 | 60 | x = self.conv_mtl(x) 61 | x = self.classify(x) 62 | x = x.view(-1, self.num_classes) 63 | return x 64 | 65 | 66 | class multitask_block(nn.Module): 67 | def __init__(self, in_ch, num_classes, kernel_size): 68 | super(multitask_block, self).__init__() 69 | self.num_classes = num_classes 70 | self.conv_mtl = nn.Conv2d(in_ch, num_classes*2, kernel_size=kernel_size, groups=num_classes) 71 | 72 | def forward(self, x): 73 | x = torch.cat(self.num_classes*[x], 1) 74 | x = self.conv_mtl(x) 75 | x = x.view(-1, self.num_classes, 2) 76 | return x 77 | -------------------------------------------------------------------------------- /torchsignal/model/Performer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Attention with Performers 3 | 
https://arxiv.org/abs/2009.14794 4 | """ 5 | # https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py 6 | 7 | import math 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | from einops import rearrange, repeat 12 | from functools import partial 13 | 14 | ## 15 | 16 | # https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/local_attention.py 17 | 18 | # class LocalAttention(nn.Module): 19 | # """Implement fast local attention where a query can only attend to 20 | # neighboring keys. 21 | # In this attention module the query Q_i can only attend to a key K_j if 22 | # |i-j| < local_context/2. 23 | # Arguments 24 | # --------- 25 | # local_context: The neighborhood to consider for local attention. 26 | # softmax_temp: The temperature to use for the softmax attention. 27 | # (default: 1/sqrt(d_keys) where d_keys is computed at 28 | # runtime) 29 | # attention_dropout: The dropout rate to apply to the attention 30 | # (default: 0.1) 31 | # event_dispatcher: str or EventDispatcher instance to be used by this 32 | # module for dispatching events (default: the default 33 | # global dispatcher) 34 | # """ 35 | # def __init__(self, local_context, softmax_temp=None, attention_dropout=0.1, 36 | # event_dispatcher=""): 37 | # super(LocalAttention, self).__init__() 38 | # self.local_context = local_context 39 | # self.softmax_temp = softmax_temp 40 | # self.dropout = Dropout(attention_dropout) 41 | # self.event_dispatcher = EventDispatcher.get(event_dispatcher) 42 | 43 | # def forward(self, queries, keys, values, attn_mask, query_lengths, 44 | # key_lengths): 45 | # """Implements the local attention. 46 | # The attn_mask can be anything but the only values that will be 47 | # considered will be the ones in the neighborhood of each query. 
48 | # Arguments 49 | # --------- 50 | # queries: (N, L, H, E) The tensor containing the queries 51 | # keys: (N, S, H, E) The tensor containing the keys 52 | # values: (N, S, H, D) The tensor containing the values 53 | # attn_mask: An implementation of BaseMask that encodes where each 54 | # query can attend to 55 | # query_lengths: An implementation of BaseMask that encodes how 56 | # many queries each sequence in the batch consists of 57 | # key_lengths: An implementation of BaseMask that encodes how 58 | # many queries each sequence in the batch consists of 59 | # """ 60 | # # Extract some shapes and compute the temperature 61 | # N, L, H, E = queries.shape 62 | # _, S, _, D = values.shape 63 | # context = self.local_context 64 | # softmax_temp = self.softmax_temp or 1./sqrt(E) 65 | 66 | # # Permute the dimensions to NHLE instead of NLHE 67 | # queries = queries.permute(0, 2, 1, 3).contiguous() 68 | # keys = keys.permute(0, 2, 1, 3).contiguous() 69 | # values = values.permute(0, 2, 1, 3).contiguous() 70 | 71 | # QK = local_dot_product( 72 | # queries, 73 | # keys, 74 | # attn_mask.additive_matrix_finite, 75 | # key_lengths.lengths, 76 | # self.local_context 77 | # ) 78 | # A = self.dropout(torch.softmax(softmax_temp * QK, dim=-1)) 79 | 80 | # V_new = local_weighted_average(A, values) 81 | 82 | # return V_new.permute(0, 2, 1, 3).contiguous() 83 | 84 | ## 85 | 86 | import torch 87 | import torch.nn as nn 88 | from operator import itemgetter 89 | from torch.autograd.function import Function 90 | from torch.utils.checkpoint import get_device_states, set_device_states 91 | 92 | # for routing arguments into the functions of the reversible layer 93 | def route_args(router, args, depth): 94 | routed_args = [(dict(), dict()) for _ in range(depth)] 95 | matched_keys = [key for key in args.keys() if key in router] 96 | 97 | for key in matched_keys: 98 | val = args[key] 99 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 100 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 101 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 102 | return routed_args 103 | 104 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 105 | class Deterministic(nn.Module): 106 | def __init__(self, net): 107 | super().__init__() 108 | self.net = net 109 | self.cpu_state = None 110 | self.cuda_in_fwd = None 111 | self.gpu_devices = None 112 | self.gpu_states = None 113 | 114 | def record_rng(self, *args): 115 | self.cpu_state = torch.get_rng_state() 116 | if torch.cuda._initialized: 117 | self.cuda_in_fwd = True 118 | self.gpu_devices, self.gpu_states = get_device_states(*args) 119 | 120 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 121 | if record_rng: 122 | self.record_rng(*args) 123 | 124 | if not set_rng: 125 | return self.net(*args, **kwargs) 126 | 127 | rng_devices = [] 128 | if self.cuda_in_fwd: 129 | rng_devices = self.gpu_devices 130 | 131 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 132 | torch.set_rng_state(self.cpu_state) 133 | if self.cuda_in_fwd: 134 | set_device_states(self.gpu_devices, self.gpu_states) 135 | return self.net(*args, **kwargs) 136 | 137 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 138 | # once multi-GPU is confirmed working, refactor and send PR back to source 139 | class ReversibleBlock(nn.Module): 140 | def __init__(self, f, 
g): 141 | super().__init__() 142 | self.f = Deterministic(f) 143 | self.g = Deterministic(g) 144 | 145 | def forward(self, x, f_args = {}, g_args = {}): 146 | x1, x2 = torch.chunk(x, 2, dim=2) 147 | y1, y2 = None, None 148 | 149 | with torch.no_grad(): 150 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 151 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 152 | 153 | return torch.cat([y1, y2], dim=2) 154 | 155 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 156 | y1, y2 = torch.chunk(y, 2, dim=2) 157 | del y 158 | 159 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 160 | del dy 161 | 162 | with torch.enable_grad(): 163 | y1.requires_grad = True 164 | gy1 = self.g(y1, set_rng=True, **g_args) 165 | torch.autograd.backward(gy1, dy2) 166 | 167 | with torch.no_grad(): 168 | x2 = y2 - gy1 169 | del y2, gy1 170 | 171 | dx1 = dy1 + y1.grad 172 | del dy1 173 | y1.grad = None 174 | 175 | with torch.enable_grad(): 176 | x2.requires_grad = True 177 | fx2 = self.f(x2, set_rng=True, **f_args) 178 | torch.autograd.backward(fx2, dx1, retain_graph=True) 179 | 180 | with torch.no_grad(): 181 | x1 = y1 - fx2 182 | del y1, fx2 183 | 184 | dx2 = dy2 + x2.grad 185 | del dy2 186 | x2.grad = None 187 | 188 | x = torch.cat([x1, x2.detach()], dim=2) 189 | dx = torch.cat([dx1, dx2], dim=2) 190 | 191 | return x, dx 192 | 193 | class _ReversibleFunction(Function): 194 | @staticmethod 195 | def forward(ctx, x, blocks, args): 196 | ctx.args = args 197 | for block, kwarg in zip(blocks, args): 198 | x = block(x, **kwarg) 199 | ctx.y = x.detach() 200 | ctx.blocks = blocks 201 | return x 202 | 203 | @staticmethod 204 | def backward(ctx, dy): 205 | y = ctx.y 206 | args = ctx.args 207 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 208 | y, dy = block.backward_pass(y, dy, **kwargs) 209 | return dy, None, None 210 | 211 | class SequentialSequence(nn.Module): 212 | def __init__(self, layers, args_route = {}): 213 | super().__init__() 214 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 215 | self.layers = layers 216 | self.args_route = args_route 217 | 218 | def forward(self, x, **kwargs): 219 | args = route_args(self.args_route, kwargs, len(self.layers)) 220 | layers_and_args = list(zip(self.layers, args)) 221 | 222 | for (f, g), (f_args, g_args) in layers_and_args: 223 | x = x + f(x, **f_args) 224 | x = x + g(x, **g_args) 225 | return x 226 | 227 | class ReversibleSequence(nn.Module): 228 | def __init__(self, blocks, args_route = {}): 229 | super().__init__() 230 | self.args_route = args_route 231 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 232 | 233 | def forward(self, x, **kwargs): 234 | x = torch.cat([x, x], dim=-1) 235 | 236 | blocks = self.blocks 237 | args = route_args(self.args_route, kwargs, len(blocks)) 238 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 239 | 240 | out = _ReversibleFunction.apply(x, blocks, args) 241 | return torch.stack(out.chunk(2, dim=-1)).sum(dim=0) 242 | 243 | ## 244 | 245 | # helpers 246 | 247 | def exists(val): 248 | return val is not None 249 | 250 | def empty(tensor): 251 | return tensor.numel() == 0 252 | 253 | def default(val, d): 254 | return val if exists(val) else d 255 | 256 | def cast_tuple(val): 257 | return (val,) if not isinstance(val, tuple) else val 258 | 259 | def get_module_device(module): 260 | return next(module.parameters()).device 261 | 262 | def find_modules(nn_module, type): 
263 | return [module for module in nn_module.modules() if isinstance(module, type)] 264 | 265 | # kernel functions 266 | 267 | # transcribed from jax to pytorch from 268 | # https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py 269 | 270 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None): 271 | b, h, *_ = data.shape 272 | 273 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 274 | 275 | ratio = (projection_matrix.shape[0] ** -0.5) 276 | 277 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 278 | 279 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 280 | 281 | diag_data = data ** 2 282 | diag_data = torch.sum(diag_data, dim=-1) 283 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2) 284 | diag_data = diag_data.unsqueeze(dim=-1) 285 | 286 | if is_query: 287 | data_dash = ratio * ( 288 | torch.exp(data_dash - diag_data - 289 | torch.max(data_dash, dim=-1, keepdim=True).values) + eps) 290 | else: 291 | data_dash = ratio * ( 292 | torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps) 293 | 294 | return data_dash 295 | 296 | def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None): 297 | b, h, *_ = data.shape 298 | 299 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 300 | 301 | if projection_matrix is None: 302 | return kernel_fn(data_normalizer * data) + kernel_epsilon 303 | 304 | projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h) 305 | 306 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 307 | 308 | data_prime = kernel_fn(data_dash) + kernel_epsilon 309 | return data_prime 310 | 311 | def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None): 312 | unstructured_block = torch.randn((cols, cols), device = device) 313 | q, r = torch.qr(unstructured_block.cpu(), some = True) 314 | q, r = map(lambda t: t.to(device), (q, r)) 315 | 316 | # proposed by @Parskatt 317 | # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf 318 | if qr_uniform_q: 319 | d = torch.diag(r, 0) 320 | q *= d.sign() 321 | return q.t() 322 | 323 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None): 324 | nb_full_blocks = int(nb_rows / nb_columns) 325 | 326 | block_list = [] 327 | 328 | for _ in range(nb_full_blocks): 329 | q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) 330 | block_list.append(q) 331 | 332 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 333 | if remaining_rows > 0: 334 | q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device) 335 | block_list.append(q[:remaining_rows]) 336 | 337 | final_matrix = torch.cat(block_list) 338 | 339 | if scaling == 0: 340 | multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1) 341 | elif scaling == 1: 342 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device) 343 | else: 344 | raise ValueError(f'Invalid scaling {scaling}') 345 | 346 | return torch.diag(multiplier) @ final_matrix 347 | 348 | # linear attention classes with softmax kernel 349 | 350 | # non-causal linear attention 351 | def linear_attention(q, k, v): 352 | D_inv = 1. 
/ torch.einsum('...nd,...d->...n', q, k.sum(dim = -2)) 353 | context = torch.einsum('...nd,...ne->...de', k, v) 354 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) 355 | return out 356 | 357 | # efficient causal linear attention, created by EPFL 358 | # TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back 359 | def causal_linear_attention(q, k, v, amp_enabled = False): 360 | from fast_transformers.causal_product import CausalDotProduct 361 | is_half = isinstance(q, torch.cuda.HalfTensor) or amp_enabled 362 | 363 | if is_half: 364 | q, k, v = map(lambda t: t.float(), (q, k, v)) 365 | 366 | D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k.cumsum(dim=-2)) 367 | out = CausalDotProduct.apply(q, k, v) 368 | out = torch.einsum('...nd,...n->...nd', out, D_inv) 369 | 370 | if is_half: 371 | out = out.half() 372 | 373 | return out 374 | 375 | # inefficient causal linear attention, without cuda code, for reader's reference 376 | # not being used 377 | def causal_linear_attention_noncuda(q, k, v): 378 | D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k.cumsum(dim=-2)) 379 | context = torch.einsum('...nd,...ne->...nde', k, v) 380 | context = context.cumsum(dim=-3) 381 | out = torch.einsum('...nde,...nd,...n->...ne', context, q, D_inv) 382 | return out 383 | 384 | class FastAttention(nn.Module): 385 | def __init__(self, dim_heads, nb_features = None, feature_redraw_interval = 0, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, amp_enabled = False): 386 | super().__init__() 387 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) 388 | 389 | self.dim_heads = dim_heads 390 | self.nb_features = nb_features 391 | self.ortho_scaling = ortho_scaling 392 | 393 | self.feature_redraw_interval = feature_redraw_interval 394 | self.register_buffer('calls_since_last_redraw', torch.tensor(0)) # Make sure this is persistent 395 | 396 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q) 397 | projection_matrix = self.create_projection() 398 | self.register_buffer('projection_matrix', projection_matrix) 399 | 400 | self.generalized_attention = generalized_attention 401 | self.kernel_fn = kernel_fn 402 | 403 | self.causal = causal 404 | if causal: 405 | try: 406 | import fast_transformers.causal_product.causal_product_cuda 407 | self.causal_linear_fn = partial(causal_linear_attention, amp_enabled = amp_enabled) 408 | except ImportError: 409 | # print('unable to import cuda code for auto-regressive Performer. 
will default to the memory inefficient non-cuda version') 410 | self.causal_linear_fn = causal_linear_attention_noncuda 411 | 412 | def forward(self, q, k, v): 413 | device = q.device 414 | 415 | # It's time to redraw the projection matrix 416 | if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval: 417 | self.projection_matrix = self.create_projection(device = device).type_as(q) 418 | self.calls_since_last_redraw = torch.tensor(0) 419 | # Keep track of how many forward passes we do before we redraw again 420 | else: 421 | self.calls_since_last_redraw += 1 422 | 423 | if self.generalized_attention: 424 | create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device) 425 | q, k = map(create_kernel, (q, k)) 426 | else: 427 | create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device) 428 | q = create_kernel(q, is_query = True) 429 | k = create_kernel(k, is_query = False) 430 | 431 | attn_fn = linear_attention if not self.causal else self.causal_linear_fn 432 | out = attn_fn(q, k, v) 433 | return out 434 | 435 | # classes 436 | 437 | class ReZero(nn.Module): 438 | def __init__(self, fn): 439 | super().__init__() 440 | self.g = nn.Parameter(torch.tensor(1e-3)) 441 | self.fn = fn 442 | 443 | def forward(self, x, **kwargs): 444 | return self.fn(x, **kwargs) * self.g 445 | 446 | class PreScaleNorm(nn.Module): 447 | def __init__(self, dim, fn, eps=1e-5): 448 | super().__init__() 449 | self.fn = fn 450 | self.g = nn.Parameter(torch.ones(1)) 451 | self.eps = eps 452 | 453 | def forward(self, x, **kwargs): 454 | n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) 455 | x = x / n * self.g 456 | return self.fn(x, **kwargs) 457 | 458 | class PreLayerNorm(nn.Module): 459 | def __init__(self, dim, fn): 460 | super().__init__() 461 | self.norm = nn.LayerNorm(dim) 462 | self.fn = fn 463 | def forward(self, x, **kwargs): 464 | return self.fn(self.norm(x), **kwargs) 465 | 466 | class Chunk(nn.Module): 467 | def __init__(self, chunks, fn, along_dim = -1): 468 | super().__init__() 469 | self.dim = along_dim 470 | self.chunks = chunks 471 | self.fn = fn 472 | 473 | def forward(self, x, **kwargs): 474 | if self.chunks == 1: 475 | return self.fn(x, **kwargs) 476 | chunks = x.chunk(self.chunks, dim = self.dim) 477 | return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim) 478 | 479 | class FeedForward(nn.Module): 480 | def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False): 481 | super().__init__() 482 | activation = default(activation, nn.GELU) 483 | 484 | self.glu = glu 485 | self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) 486 | self.act = activation() 487 | self.dropout = nn.Dropout(dropout) 488 | self.w2 = nn.Linear(dim * mult, dim) 489 | 490 | def forward(self, x, **kwargs): 491 | if not self.glu: 492 | x = self.w1(x) 493 | x = self.act(x) 494 | else: 495 | x, v = self.w1(x).chunk(2, dim=-1) 496 | x = self.act(x) * v 497 | 498 | x = self.dropout(x) 499 | x = self.w2(x) 500 | return x 501 | 502 | class SelfAttention(nn.Module): 503 | def __init__(self, dim, causal = False, heads = 8, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., amp_enabled = False): 504 | super().__init__() 505 | assert dim % heads == 0, 'dimension must be divisible by number of heads' 506 | 
dim_head = dim // heads 507 | self.fast_attention = FastAttention(dim_head, nb_features, feature_redraw_interval, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, amp_enabled = amp_enabled) 508 | 509 | self.heads = heads 510 | self.global_heads = heads - local_heads 511 | self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None 512 | 513 | self.to_q = nn.Linear(dim, dim) 514 | self.to_k = nn.Linear(dim, dim) 515 | self.to_v = nn.Linear(dim, dim) 516 | self.to_out = nn.Linear(dim, dim) 517 | self.dropout = nn.Dropout(dropout) 518 | 519 | def forward(self, x, context = None, mask = None, context_mask = None): 520 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads 521 | 522 | cross_attend = exists(context) 523 | context = default(context, x) 524 | context_mask = default(context_mask, mask) 525 | 526 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 527 | 528 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 529 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) 530 | 531 | attn_outs = [] 532 | 533 | if not empty(q): 534 | if exists(context_mask): 535 | global_mask = context_mask[:, None, :, None] 536 | k.masked_fill_(~global_mask, 0) 537 | 538 | out = self.fast_attention(q, k, v) 539 | attn_outs.append(out) 540 | 541 | if not empty(lq): 542 | assert 'local attention is not compatible with cross attention' 543 | out = self.local_attn(lq, lk, lv, input_mask = mask) 544 | attn_outs.append(out) 545 | 546 | out = torch.cat(attn_outs, dim = 1) 547 | out = rearrange(out, 'b h n d -> b n (h d)') 548 | out = self.to_out(out) 549 | return self.dropout(out) 550 | 551 | class Performer(nn.Module): 552 | def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, ff_glu = False, ff_dropout = 0., attn_dropout = 0., cross_attend = False, amp_enabled = False): 553 | super().__init__() 554 | layers = nn.ModuleList([]) 555 | local_attn_heads = cast_tuple(local_attn_heads) 556 | local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads 557 | assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth' 558 | assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads' 559 | 560 | if use_scalenorm: 561 | wrapper_fn = partial(PreScaleNorm, dim) 562 | elif use_rezero: 563 | wrapper_fn = ReZero 564 | else: 565 | wrapper_fn = partial(PreLayerNorm, dim) 566 | 567 | for _, local_heads in zip(range(depth), local_attn_heads): 568 | layers.append(nn.ModuleList([ 569 | wrapper_fn(SelfAttention(dim, causal = causal, heads = heads, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, dropout = attn_dropout, amp_enabled = amp_enabled)), 570 | wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) 571 | ])) 572 
| 573 | if not cross_attend: 574 | continue 575 | 576 | layers.append(nn.ModuleList([ 577 | wrapper_fn(SelfAttention(dim, heads = heads, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, dropout = attn_dropout)), 578 | wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1)) 579 | ])) 580 | 581 | execute_type = ReversibleSequence if reversible else SequentialSequence 582 | 583 | route_attn = ((True, False),) * depth * (2 if cross_attend else 1) 584 | route_context = ((False, False), (True, False)) * depth 585 | attn_route_map = {'mask': route_attn} 586 | context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {} 587 | self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map}) 588 | 589 | def forward(self, x, **kwargs): 590 | return self.net(x, **kwargs) 591 | 592 | class PerformerLM(nn.Module): 593 | def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0., generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, cross_attend = False, amp_enabled = False): 594 | super().__init__() 595 | local_attn_heads = cast_tuple(local_attn_heads) 596 | 597 | self.max_seq_len = max_seq_len 598 | self.token_emb = nn.Embedding(num_tokens, dim) 599 | self.pos_emb = nn.Embedding(max_seq_len, dim) 600 | self.dropout = nn.Dropout(emb_dropout) 601 | 602 | nn.init.normal_(self.token_emb.weight, std = 0.02) 603 | nn.init.normal_(self.pos_emb.weight, std = 0.02) 604 | 605 | self.performer = Performer(dim, depth, heads, local_attn_heads, local_window_size, causal, ff_mult, nb_features, reversible, ff_chunks, generalized_attention, kernel_fn, qr_uniform_q, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, amp_enabled) 606 | self.norm = nn.LayerNorm(dim) 607 | 608 | def fix_projection_matrices_(self): 609 | fast_attentions = find_modules(self, FastAttention) 610 | device = get_module_device(self) 611 | for fast_attention in fast_attentions: 612 | fast_attention.feature_redraw_interval = None 613 | 614 | def forward(self, x, return_encodings = False, **kwargs): 615 | b, n, device = *x.shape, x.device 616 | # token and positional embeddings 617 | x = self.token_emb(x) 618 | x += self.pos_emb(torch.arange(n, device = device)) 619 | x = self.dropout(x) 620 | 621 | # performer layers 622 | x = self.performer(x, **kwargs) 623 | 624 | # norm and to logits 625 | x = self.norm(x) 626 | 627 | if return_encodings: 628 | return x 629 | 630 | return x @ self.token_emb.weight.t() -------------------------------------------------------------------------------- /torchsignal/model/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | ## Multitask Model 4 | 5 | Using multi-task learning to capture signals simultaneously from the fovea efficiently and the neighboring targets in the peripheral vision generate a visual response map. A calibration-free user-independent solution, desirable for clinical diagnostics. A stepping stone for an objective assessment of glaucoma patients’ visual field. Learn more about this model at https://jinglescode.github.io/ssvep-multi-task-learning/. 
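With these arguments, the example below maps an input batch of shape (20, 10, 1000) to an output of shape (20, 40, 2): one pair of logits per target class, reflecting the model's multi-label formulation.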
6 | 7 | ``` 8 | from torchsignal.model import MultitaskSSVEP 9 | 10 | model = MultitaskSSVEP(num_channel=10, 11 | num_classes=40, 12 | signal_length=1000, 13 | filters_n1=4, 14 | kernel_window_ssvep=59, 15 | kernel_window=19, 16 | conv_3_dilation=4, 17 | conv_4_dilation=4, 18 | ) 19 | 20 | x = torch.ones((20, 10, 1000)) 21 | print("Input shape:", x.shape) 22 | y = model(x) 23 | print("Output shape:", y.shape) 24 | ``` 25 | 26 | ## EEGNet (Compact) 27 | 28 | EEGNet: Compact Convolutional Neural Network (Compact-CNN) https://arxiv.org/pdf/1803.04566.pdf 29 | 30 | ``` 31 | from torchsignal.model import CompactEEGNet 32 | 33 | model = CompactEEGNet( 34 | num_channel=10, 35 | num_classes=4, 36 | signal_length=1000, 37 | ) 38 | 39 | x = torch.ones((21, 10, 1000)) 40 | print("Input shape:", x.shape) 41 | y = model(x) 42 | print("Output shape:", y.shape) 43 | ``` 44 | 45 | ## Performer 46 | 47 | Rethinking Attention with Performers 48 | https://arxiv.org/abs/2009.14794 49 | 50 | ``` 51 | from torchsignal.model import Performer 52 | 53 | model = Performer( 54 | dim = 11, 55 | depth = 1, 56 | heads = 1, 57 | causal = True 58 | ) 59 | 60 | x = torch.randn(1, 1000, 11) 61 | print("Input shape:", x.shape) # torch.Size([1, 1000, 11]) 62 | y = model(x) 63 | print("Output shape:", y.shape) # torch.Size([1, 1000, 11]) 64 | ``` 65 | 66 | ## WaveNet 67 | 68 | WaveNet: A Generative Model for Raw Audio 69 | https://arxiv.org/abs/1609.03499. 70 | 71 | ``` 72 | from torchsignal.model import WaveNet 73 | 74 | model = WaveNet( 75 | layers=6, 76 | blocks=3, 77 | dilation_channels=32, 78 | residual_channels=32, 79 | skip_channels=1024, 80 | classes=9, 81 | end_channels=512, 82 | bias=True 83 | ) 84 | 85 | x = torch.randn(2, 9, 250) 86 | print("Input shape:", x.shape) # torch.Size([2, 9, 250]) 87 | y = model(x) 88 | print("Output shape:", y.shape) # torch.Size([2, 9, 128]) 89 | ``` 90 | 91 | ## MI-EEGNet 92 | MI-EEGNET: A novel Convolutional Neural Network for motor imagery classification 93 | https://www.sciencedirect.com/science/article/abs/pii/S016502702030460X 94 | 95 | ``` 96 | from torchsignal.model import MIEEGNet 97 | 98 | model = MIEEGNet(num_channel=22, num_classes=12, signal_length=256) 99 | 100 | x = torch.randn(1, 22, 256) 101 | print("Input shape:", x.shape) # torch.Size([1, 22, 256]) 102 | y = model(x) 103 | print("Output shape:", y.shape) # torch.Size([1, 12]) 104 | ``` 105 | -------------------------------------------------------------------------------- /torchsignal/model/WaveNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is an implementation of the WaveNet architecture, as described in the original paper, WaveNet: A Generative Model for Raw Audio, https://arxiv.org/abs/1609.03499. 3 | Taken from https://github.com/vincentherrmann/pytorch-wavenet. 4 | Adapted for PyTorch 1.6. 5 | """ 6 | 7 | ################################### 8 | # https://github.com/vincentherrmann/pytorch-wavenet/blob/master/wavenet_modules.py 9 | 10 | import math 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable, Function 16 | from torch.nn import Parameter 17 | 18 | 19 | def dilate(x, dilation, init_dilation=1, pad_start=True): 20 | """ 21 | :param x: Tensor of size (N, C, L), where N is the input dilation, C is the number of channels, and L is the input length 22 | :param dilation: Target dilation. Will be the size of the first dimension of the output tensor. 
23 | :param pad_start: If the input length is not compatible with the specified dilation, zero padding is used. This parameter determines wether the zeros are added at the start or at the end. 24 | :return: The dilated tensor of size (dilation, C, L*N / dilation). The output might be zero padded at the start 25 | """ 26 | 27 | [n, c, l] = x.size() 28 | dilation_factor = dilation / init_dilation 29 | if dilation_factor == 1: 30 | return x 31 | 32 | # zero padding for reshaping 33 | new_l = int(np.ceil(l / dilation_factor) * dilation_factor) 34 | if new_l != l: 35 | l = new_l 36 | x = constant_pad_1d(x, new_l, pad_start=pad_start) 37 | 38 | l_old = int(round(l / dilation_factor)) 39 | n_old = int(round(n * dilation_factor)) 40 | l = math.ceil(l * init_dilation / dilation) 41 | n = math.ceil(n * dilation / init_dilation) 42 | 43 | # reshape according to dilation 44 | x = x.permute(1, 2, 0).contiguous() # (n, c, l) -> (c, l, n) 45 | x = x.view(c, l, n) 46 | x = x.permute(2, 0, 1).contiguous() # (c, l, n) -> (n, c, l) 47 | 48 | return x 49 | 50 | 51 | class DilatedQueue: 52 | def __init__(self, max_length, data=None, dilation=1, num_deq=1, num_channels=1, dtype=torch.FloatTensor): 53 | self.in_pos = 0 54 | self.out_pos = 0 55 | self.num_deq = num_deq 56 | self.num_channels = num_channels 57 | self.dilation = dilation 58 | self.max_length = max_length 59 | self.data = data 60 | self.dtype = dtype 61 | if data == None: 62 | self.data = Variable(dtype(num_channels, max_length).zero_()) 63 | 64 | def enqueue(self, input): 65 | self.data[:, self.in_pos] = input 66 | self.in_pos = (self.in_pos + 1) % self.max_length 67 | 68 | def dequeue(self, num_deq=1, dilation=1): 69 | # | 70 | # |6|7|8|1|2|3|4|5| 71 | # | 72 | start = self.out_pos - ((num_deq - 1) * dilation) 73 | if start < 0: 74 | t1 = self.data[:, start::dilation] 75 | t2 = self.data[:, self.out_pos % dilation:self.out_pos + 1:dilation] 76 | t = torch.cat((t1, t2), 1) 77 | else: 78 | t = self.data[:, start:self.out_pos + 1:dilation] 79 | 80 | self.out_pos = (self.out_pos + 1) % self.max_length 81 | return t 82 | 83 | def reset(self): 84 | self.data = Variable(self.dtype(self.num_channels, self.max_length).zero_()) 85 | self.in_pos = 0 86 | self.out_pos = 0 87 | 88 | 89 | def constant_pad_1d(input, 90 | target_size, 91 | value=0, 92 | pad_start=False): 93 | """ 94 | Assumes that padded dim is the 2, based on pytorch specification. 
95 | Input: (N,C,Win)(N, C, W_{in})(N,C,Win​) 96 | Output: (N,C,Wout)(N, C, W_{out})(N,C,Wout​) where 97 | :param input: 98 | :param target_size: 99 | :param value: 100 | :param pad_start: 101 | :return: 102 | """ 103 | num_pad = target_size - input.size(2) 104 | assert num_pad >= 0, 'target size has to be greater than input size' 105 | padding = (num_pad, 0) if pad_start else (0, num_pad) 106 | return torch.nn.ConstantPad1d(padding, value)(input) 107 | 108 | 109 | 110 | ################################### 111 | # https://github.com/vincentherrmann/pytorch-wavenet/blob/master/wavenet_model.py 112 | 113 | import os 114 | import os.path 115 | import time 116 | import torch.nn.functional as F 117 | from torch import nn 118 | 119 | 120 | 121 | class WaveNet(nn.Module): 122 | """ 123 | A Complete Wavenet Model 124 | Args: 125 | layers (Int): Number of layers in each block 126 | blocks (Int): Number of wavenet blocks of this model 127 | dilation_channels (Int): Number of channels for the dilated convolution 128 | residual_channels (Int): Number of channels for the residual connection 129 | skip_channels (Int): Number of channels for the skip connections 130 | classes (Int): Number of possible values each sample can have 131 | output_length (Int): Number of samples that are generated for each input 132 | kernel_size (Int): Size of the dilation kernel 133 | dtype: Parameter type of this model 134 | Shape: 135 | - Input: :math:`(N, C_{in}, L_{in})` 136 | - Output: :math:`()` 137 | L should be the length of the receptive field 138 | """ 139 | 140 | def __init__(self, 141 | layers=10, 142 | blocks=4, 143 | dilation_channels=32, 144 | residual_channels=32, 145 | skip_channels=256, 146 | end_channels=256, 147 | classes=256, 148 | output_length=32, 149 | kernel_size=2, 150 | dtype=torch.FloatTensor, 151 | bias=False): 152 | 153 | super(WaveNet, self).__init__() 154 | 155 | self.layers = layers 156 | self.blocks = blocks 157 | self.dilation_channels = dilation_channels 158 | self.residual_channels = residual_channels 159 | self.skip_channels = skip_channels 160 | self.classes = classes 161 | self.kernel_size = kernel_size 162 | self.dtype = dtype 163 | 164 | # build model 165 | receptive_field = 1 166 | init_dilation = 1 167 | 168 | self.dilations = [] 169 | self.dilated_queues = [] 170 | # self.main_convs = nn.ModuleList() 171 | self.filter_convs = nn.ModuleList() 172 | self.gate_convs = nn.ModuleList() 173 | self.residual_convs = nn.ModuleList() 174 | self.skip_convs = nn.ModuleList() 175 | 176 | # 1x1 convolution to create channels 177 | self.start_conv = nn.Conv1d(in_channels=self.classes, 178 | out_channels=residual_channels, 179 | kernel_size=1, 180 | bias=bias) 181 | 182 | for b in range(blocks): 183 | additional_scope = kernel_size - 1 184 | new_dilation = 1 185 | for i in range(layers): 186 | # dilations of this layer 187 | self.dilations.append((new_dilation, init_dilation)) 188 | 189 | # dilated queues for fast generation 190 | self.dilated_queues.append(DilatedQueue(max_length=(kernel_size - 1) * new_dilation + 1, 191 | num_channels=residual_channels, 192 | dilation=new_dilation, 193 | dtype=dtype)) 194 | 195 | # dilated convolutions 196 | self.filter_convs.append(nn.Conv1d(in_channels=residual_channels, 197 | out_channels=dilation_channels, 198 | kernel_size=kernel_size, 199 | bias=bias)) 200 | 201 | self.gate_convs.append(nn.Conv1d(in_channels=residual_channels, 202 | out_channels=dilation_channels, 203 | kernel_size=kernel_size, 204 | bias=bias)) 205 | 206 | # 1x1 convolution for residual 
connection 207 | self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels, 208 | out_channels=residual_channels, 209 | kernel_size=1, 210 | bias=bias)) 211 | 212 | # 1x1 convolution for skip connection 213 | self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels, 214 | out_channels=skip_channels, 215 | kernel_size=1, 216 | bias=bias)) 217 | 218 | receptive_field += additional_scope 219 | additional_scope *= 2 220 | init_dilation = new_dilation 221 | new_dilation *= 2 222 | 223 | self.end_conv_1 = nn.Conv1d(in_channels=skip_channels, 224 | out_channels=end_channels, 225 | kernel_size=1, 226 | bias=True) 227 | 228 | self.end_conv_2 = nn.Conv1d(in_channels=end_channels, 229 | out_channels=classes, 230 | kernel_size=1, 231 | bias=True) 232 | 233 | # self.output_length = 2 ** (layers - 1) 234 | self.output_length = output_length 235 | self.receptive_field = receptive_field 236 | 237 | def wavenet(self, input, dilation_func): 238 | 239 | x = self.start_conv(input) 240 | skip = 0 241 | 242 | # WaveNet layers 243 | for i in range(self.blocks * self.layers): 244 | 245 | # |----------------------------------------| *residual* 246 | # | | 247 | # | |-- conv -- tanh --| | 248 | # -> dilate -|----| * ----|-- 1x1 -- + --> *input* 249 | # |-- conv -- sigm --| | 250 | # 1x1 251 | # | 252 | # ---------------------------------------> + -------------> *skip* 253 | 254 | (dilation, init_dilation) = self.dilations[i] 255 | 256 | residual = dilation_func(x, dilation, init_dilation, i) 257 | 258 | # dilated convolution 259 | filter = self.filter_convs[i](residual) 260 | filter = torch.tanh(filter) 261 | gate = self.gate_convs[i](residual) 262 | gate = torch.sigmoid(gate) 263 | x = filter * gate 264 | 265 | # parametrized skip connection 266 | s = x 267 | if x.size(2) != 1: 268 | s = dilate(x, 1, init_dilation=dilation) 269 | s = self.skip_convs[i](s) 270 | try: 271 | skip = skip[:, :, -s.size(2):] 272 | except: 273 | skip = 0 274 | skip = s + skip 275 | 276 | x = self.residual_convs[i](x) 277 | x = x + residual[:, :, (self.kernel_size - 1):] 278 | 279 | x = F.relu(skip) 280 | x = F.relu(self.end_conv_1(x)) 281 | x = self.end_conv_2(x) 282 | 283 | return x 284 | 285 | def wavenet_dilate(self, input, dilation, init_dilation, i): 286 | x = dilate(input, dilation, init_dilation) 287 | return x 288 | 289 | def queue_dilate(self, input, dilation, init_dilation, i): 290 | queue = self.dilated_queues[i] 291 | queue.enqueue(input.data[0]) 292 | x = queue.dequeue(num_deq=self.kernel_size, 293 | dilation=dilation) 294 | x = x.unsqueeze(0) 295 | 296 | return x 297 | 298 | def forward(self, input): 299 | x = self.wavenet(input, 300 | dilation_func=self.wavenet_dilate) 301 | 302 | # reshape output 303 | # [n, c, l] = x.size() 304 | # l = self.output_length 305 | # x = x[:, :, -l:] 306 | # x = x.transpose(1, 2).contiguous() 307 | # x = x.view(n * l, c) 308 | return x 309 | 310 | def generate(self, 311 | num_samples, 312 | first_samples=None, 313 | temperature=1.): 314 | self.eval() 315 | if first_samples is None: 316 | first_samples = self.dtype(1).zero_() 317 | generated = Variable(first_samples, volatile=True) 318 | 319 | num_pad = self.receptive_field - generated.size(0) 320 | if num_pad > 0: 321 | generated = constant_pad_1d(generated, self.scope, pad_start=True) 322 | print("pad zero") 323 | 324 | for i in range(num_samples): 325 | input = Variable(torch.FloatTensor(1, self.classes, self.receptive_field).zero_()) 326 | input = input.scatter_(1, generated[-self.receptive_field:].view(1, -1, 
self.receptive_field), 1.) 327 | 328 | x = self.wavenet(input, 329 | dilation_func=self.wavenet_dilate)[:, :, -1].squeeze() 330 | 331 | if temperature > 0: 332 | x /= temperature 333 | prob = F.softmax(x, dim=0) 334 | prob = prob.cpu() 335 | np_prob = prob.data.numpy() 336 | x = np.random.choice(self.classes, p=np_prob) 337 | x = Variable(torch.LongTensor([x])) # np.array([x]) 338 | else: 339 | x = torch.max(x, 0)[1].float() 340 | 341 | generated = torch.cat((generated, x), 0) 342 | 343 | generated = (generated / self.classes) * 2. - 1 344 | mu_gen = mu_law_expansion(generated, self.classes) 345 | 346 | self.train() 347 | return mu_gen 348 | 349 | def generate_fast(self, 350 | num_samples, 351 | first_samples=None, 352 | temperature=1., 353 | regularize=0., 354 | progress_callback=None, 355 | progress_interval=100): 356 | self.eval() 357 | if first_samples is None: 358 | first_samples = torch.LongTensor(1).zero_() + (self.classes // 2) 359 | first_samples = Variable(first_samples) 360 | 361 | # reset queues 362 | for queue in self.dilated_queues: 363 | queue.reset() 364 | 365 | num_given_samples = first_samples.size(0) 366 | total_samples = num_given_samples + num_samples 367 | 368 | input = Variable(torch.FloatTensor(1, self.classes, 1).zero_()) 369 | input = input.scatter_(1, first_samples[0:1].view(1, -1, 1), 1.) 370 | 371 | # fill queues with given samples 372 | for i in range(num_given_samples - 1): 373 | x = self.wavenet(input, 374 | dilation_func=self.queue_dilate) 375 | input.zero_() 376 | input = input.scatter_(1, first_samples[i + 1:i + 2].view(1, -1, 1), 1.).view(1, self.classes, 1) 377 | 378 | # progress feedback 379 | if i % progress_interval == 0: 380 | if progress_callback is not None: 381 | progress_callback(i, total_samples) 382 | 383 | # generate new samples 384 | generated = np.array([]) 385 | regularizer = torch.pow(Variable(torch.arange(self.classes)) - self.classes / 2., 2) 386 | regularizer = regularizer.squeeze() * regularize 387 | tic = time.time() 388 | for i in range(num_samples): 389 | x = self.wavenet(input, 390 | dilation_func=self.queue_dilate).squeeze() 391 | 392 | x -= regularizer 393 | 394 | if temperature > 0: 395 | # sample from softmax distribution 396 | x /= temperature 397 | prob = F.softmax(x, dim=0) 398 | prob = prob.cpu() 399 | np_prob = prob.data.numpy() 400 | x = np.random.choice(self.classes, p=np_prob) 401 | x = np.array([x]) 402 | else: 403 | # convert to sample value 404 | x = torch.max(x, 0)[1][0] 405 | x = x.cpu() 406 | x = x.data.numpy() 407 | 408 | o = (x / self.classes) * 2. 
- 1 409 | generated = np.append(generated, o) 410 | 411 | # set new input 412 | x = Variable(torch.from_numpy(x).type(torch.LongTensor)) 413 | input.zero_() 414 | input = input.scatter_(1, x.view(1, -1, 1), 1.).view(1, self.classes, 1) 415 | 416 | if (i + 1) == 100: 417 | toc = time.time() 418 | print("one generating step does take approximately " + str((toc - tic) * 0.01) + " seconds)") 419 | 420 | # progress feedback 421 | if (i + num_given_samples) % progress_interval == 0: 422 | if progress_callback is not None: 423 | progress_callback(i + num_given_samples, total_samples) 424 | 425 | self.train() 426 | mu_gen = mu_law_expansion(generated, self.classes) 427 | return mu_gen 428 | 429 | def parameter_count(self): 430 | par = list(self.parameters()) 431 | s = sum([np.prod(list(d.size())) for d in par]) 432 | return s 433 | 434 | def cpu(self, type=torch.FloatTensor): 435 | self.dtype = type 436 | for q in self.dilated_queues: 437 | q.dtype = self.dtype 438 | super().cpu() 439 | 440 | 441 | def load_latest_model_from(location, use_cuda=True): 442 | files = [location + "/" + f for f in os.listdir(location)] 443 | newest_file = max(files, key=os.path.getctime) 444 | print("load model " + newest_file) 445 | 446 | if use_cuda: 447 | model = torch.load(newest_file) 448 | else: 449 | model = load_to_cpu(newest_file) 450 | 451 | return model 452 | 453 | 454 | def load_to_cpu(path): 455 | model = torch.load(path, map_location=lambda storage, loc: storage) 456 | model.cpu() 457 | return model 458 | 459 | 460 | 461 | ########################## 462 | 463 | def test(): 464 | 465 | model = WaveNet( 466 | layers=6, 467 | blocks=3, 468 | dilation_channels=32, 469 | residual_channels=32, 470 | skip_channels=1024, 471 | classes=9, 472 | end_channels=512, 473 | bias=True 474 | ) 475 | 476 | x = torch.randn(2, 9, 250) 477 | print("Input shape:", x.shape) # torch.Size([2, 9, 250]) 478 | y = model(x) 479 | print("Output shape:", y.shape) # torch.Size([2, 9, 128]) 480 | 481 | 482 | if __name__ == "__main__": 483 | test() -------------------------------------------------------------------------------- /torchsignal/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .MultitaskSSVEP import MultitaskSSVEP 2 | from .CompactEEGNet import CompactEEGNet 3 | from .Performer import Performer, PerformerLM 4 | from .WaveNet import WaveNet 5 | from .MIEEGNet import MIEEGNet -------------------------------------------------------------------------------- /torchsignal/model/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinglescode/torchsignal/6172bc2b18eeafa9464cfba678e9c02ea4ed5e2a/torchsignal/model/common/__init__.py -------------------------------------------------------------------------------- /torchsignal/model/common/conv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Common convolutions 3 | """ 4 | import torch 5 | from torch import nn 6 | from torch.nn.utils import weight_norm 7 | import torch.nn.functional as F 8 | 9 | from typing import List 10 | import math 11 | 12 | 13 | class Conv2d(nn.Module): 14 | """ 15 | Input: 4-dim tensor 16 | Shape [batch, in_channels, H, W] 17 | Return: 4-dim tensor 18 | Shape [batch, out_channels, H, W] 19 | 20 | Args: 21 | in_channels : int 22 | Should match input `channel` 23 | out_channels : int 24 | Return tensor with `out_channels` 25 | kernel_size : int or 2-dim 
tuple 26 | stride : int or 2-dim tuple, default: 1 27 | padding : int or 2-dim tuple or "SAME", default: "SAME" 28 | Apply `padding` if given an int or 2-dim tuple. Perform TensorFlow-like 'SAME' padding if set to "SAME" 29 | dilation : int or 2-dim tuple, default: 1 30 | groups : int or 2-dim tuple, default: 1 31 | w_in: int, optional 32 | The size of `W` axis. If given, `w_out` is available. 33 | 34 | Usage: 35 | x = torch.randn(1, 22, 1, 256) 36 | conv1 = Conv2d(22, 64, kernel_size=17, padding="SAME", w_in=256) 37 | y = conv1(x) 38 | """ 39 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding="SAME", dilation=1, groups=1, w_in=None): 40 | super().__init__() 41 | 42 | 43 | self.kernel_size = kernel_size 44 | self.stride = stride 45 | self.dilation = dilation 46 | 47 | self.padding_same = False 48 | if padding == "SAME": 49 | self.padding_same = True 50 | padding = (0,0) 51 | 52 | if isinstance(padding, int): 53 | padding = (padding, padding) 54 | 55 | if isinstance(kernel_size, int): 56 | self.kernel_size = kernel_size = (kernel_size, kernel_size) 57 | 58 | if isinstance(stride, int): 59 | self.stride = stride = (stride, stride) 60 | 61 | if isinstance(dilation, int): 62 | self.dilation = dilation = (dilation, dilation) 63 | 64 | self.conv = nn.Conv2d( 65 | in_channels, 66 | out_channels, 67 | kernel_size=kernel_size, 68 | stride=stride, 69 | padding=padding, 70 | dilation=dilation, 71 | groups=groups 72 | ) 73 | 74 | if w_in is not None: 75 | self.w_out = int( ((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1]-1)-1) / 1) + 1 ) 76 | if self.padding_same: # with 'SAME' padding, w_out equals w_in 77 | self.w_out = w_in 78 | 79 | def forward(self, x): 80 | if self.padding_same: 81 | x = self.pad_same(x, self.kernel_size, self.stride, self.dilation) 82 | return self.conv(x) 83 | 84 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 85 | def get_same_padding(self, x: int, k: int, s: int, d: int): 86 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 87 | 88 | # Dynamically pad input x with 'SAME' padding for conv with specified args 89 | def pad_same(self, x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 90 | ih, iw = x.size()[-2:] 91 | pad_h, pad_w = self.get_same_padding(ih, k[0], s[0], d[0]), self.get_same_padding(iw, k[1], s[1], d[1]) 92 | if pad_h > 0 or pad_w > 0: 93 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 94 | return x 95 | 96 | 97 | class Conv2dBlockELU(nn.Module): 98 | def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, activation=nn.ELU, w_in=None): 99 | super(Conv2dBlockELU, self).__init__() 100 | self.conv = nn.Sequential( 101 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups), 102 | nn.BatchNorm2d(out_channels), 103 | activation(inplace=True) 104 | ) 105 | 106 | if w_in is not None: 107 | self.w_out = int( ((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1]-1)-1) / 1) + 1 ) 108 | 109 | def forward(self, x): 110 | return self.conv(x) 111 | 112 | 113 | class DepthwiseConv2d(nn.Module): 114 | def __init__(self, in_channels, out_channels, kernel_size, depth=1, padding=0, bias=False): 115 | super(DepthwiseConv2d, self).__init__() 116 | self.depthwise = nn.Conv2d(in_channels, out_channels*depth, kernel_size=kernel_size,
padding=padding, groups=in_channels, bias=bias) 117 | 118 | def forward(self, x): 119 | x = self.depthwise(x) 120 | return x 121 | 122 | 123 | class SeparableConv2d(nn.Module): 124 | def __init__(self, in_channels, out_channels, kernel_size, bias=False): 125 | super(SeparableConv2d, self).__init__() 126 | 127 | if isinstance(kernel_size, int): 128 | padding = kernel_size // 2 129 | 130 | if isinstance(kernel_size, tuple): 131 | padding = ( 132 | kernel_size[0]//2 if kernel_size[0]-1 != 0 else 0, 133 | kernel_size[1]//2 if kernel_size[1]-1 != 0 else 0 134 | ) 135 | 136 | self.depthwise = DepthwiseConv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, padding=padding, bias=bias) 137 | self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias) 138 | 139 | def forward(self, x): 140 | x = self.depthwise(x) 141 | x = self.pointwise(x) 142 | return x 143 | 144 | 145 | class SelfAttention(nn.Module): 146 | """Self attention layer 147 | 148 | Inputs : 149 | x : input feature maps( B X C X W X H) 150 | Returns : 151 | out : self attention value + input feature 152 | attention: B X N X N (N is Width*Height) 153 | Usage: 154 | selfattn = SelfAttention(11) 155 | x = torch.randn(2, 11, 128, 128) 156 | print("Input shape:", x.shape) 157 | y, attention = selfattn(x) 158 | print("Output shape:", y.shape) 159 | print("Attention shape:", attention.shape) 160 | """ 161 | def __init__(self, in_dim, activation=nn.ReLU): 162 | super().__init__() 163 | self.chanel_in = in_dim 164 | self.activation = activation 165 | 166 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size= 1) 167 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size= 1) 168 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size= 1) 169 | self.gamma = nn.Parameter(torch.zeros(1)) 170 | 171 | self.softmax = nn.Softmax(dim=-1) # 172 | 173 | def forward(self,x): 174 | m_batchsize, C, width, height = x.size() 175 | proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N) 176 | proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H) 177 | energy = torch.bmm(proj_query,proj_key) # transpose check 178 | attention = self.softmax(energy) # BX (N) X (N) 179 | proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N 180 | 181 | out = torch.bmm(proj_value,attention.permute(0,2,1) ) 182 | out = out.view(m_batchsize,C,width,height) 183 | 184 | out = self.gamma*out + x 185 | return out, attention 186 | 187 | 188 | ##### 189 | # Sequence Modeling Benchmarks and Temporal Convolutional Networks (TCN) 190 | 191 | class Chomp1d(nn.Module): 192 | def __init__(self, chomp_size): 193 | super(Chomp1d, self).__init__() 194 | self.chomp_size = chomp_size 195 | 196 | def forward(self, x): 197 | return x[:, :, :-self.chomp_size].contiguous() 198 | 199 | 200 | class TemporalBlock(nn.Module): 201 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): 202 | super(TemporalBlock, self).__init__() 203 | self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, 204 | stride=stride, padding=padding, dilation=dilation)) 205 | self.chomp1 = Chomp1d(padding) 206 | self.relu1 = nn.ReLU() 207 | self.dropout1 = nn.Dropout(dropout) 208 | 209 | self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, 210 | stride=stride, padding=padding, dilation=dilation)) 211 | self.chomp2 = Chomp1d(padding) 212 | self.relu2 
= nn.ReLU() 213 | self.dropout2 = nn.Dropout(dropout) 214 | 215 | self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, 216 | self.conv2, self.chomp2, self.relu2, self.dropout2) 217 | self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None 218 | self.relu = nn.ReLU() 219 | self.init_weights() 220 | 221 | def init_weights(self): 222 | self.conv1.weight.data.normal_(0, 0.01) 223 | self.conv2.weight.data.normal_(0, 0.01) 224 | if self.downsample is not None: 225 | self.downsample.weight.data.normal_(0, 0.01) 226 | 227 | def forward(self, x): 228 | out = self.net(x) 229 | res = x if self.downsample is None else self.downsample(x) 230 | return self.relu(out + res) 231 | 232 | 233 | class TemporalConvNet(nn.Module): 234 | """ 235 | TCN layer 236 | An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling 237 | https://arxiv.org/abs/1803.01271 238 | https://github.com/locuslab/TCN 239 | 240 | Usage: 241 | tcn = TemporalConvNet( 242 | num_channels=11 243 | ) 244 | x = torch.randn(2, 11, 250) 245 | print("Input shape:", x.shape) 246 | y = tcn(x) 247 | print("Output shape:", y.shape) 248 | 249 | """ 250 | def __init__(self, num_channels, kernel_size=7, dropout=0.1, nhid=32, levels=8): 251 | super(TemporalConvNet, self).__init__() 252 | 253 | channel_sizes = [nhid] * levels 254 | 255 | layers = [] 256 | num_levels = len(channel_sizes) 257 | for i in range(num_levels): 258 | dilation_size = 2 ** i 259 | in_channels = num_channels if i == 0 else channel_sizes[i-1] 260 | out_channels = channel_sizes[i] 261 | layers += [ 262 | TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, 263 | padding=(kernel_size-1) * dilation_size, dropout=dropout)] 264 | 265 | self.network = nn.Sequential(*layers) 266 | 267 | def forward(self, x): 268 | return self.network(x) 269 | -------------------------------------------------------------------------------- /torchsignal/model/common/utils.py: -------------------------------------------------------------------------------- 1 | def count_params(model): 2 | """ 3 | Count number of trainable parameters in PyTorch model 4 | """ 5 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 6 | return num_params -------------------------------------------------------------------------------- /torchsignal/trainer/multitask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | import numpy as np 4 | import time 5 | import copy 6 | from sklearn.metrics import accuracy_score, f1_score 7 | 8 | 9 | class Multitask_Trainer(object): 10 | 11 | def __init__(self, model, model_name='model', optimizer=None, criterion=None, scheduler=None, learning_rate=0.001, device='cpu', num_classes=4, multitask_learning=True, patience=10, verbose=False, print_training_metric=False, warmup_lr=False, **kwargs): 12 | 13 | SEED = 12 14 | torch.manual_seed(SEED) 15 | np.random.seed(SEED) 16 | 17 | self.device = device 18 | self.verbose = verbose 19 | self.print_training_metric = print_training_metric 20 | self.num_classes = num_classes 21 | self.model_path = '_tmp_models/'+str(model_name)+'.pth' 22 | self.learning_rate = learning_rate 23 | self.warmup_lr = warmup_lr 24 | 25 | self.multitask_learning = multitask_learning 26 | 27 | self.model = model.to(self.device) 28 | 29 | params_to_update = self.get_parameters() 30 | 31 | if optimizer: 32 | self.optimizer = optimizer 33 | else: 34 | 
self.optimizer = optim.Adam(params_to_update, lr=learning_rate, weight_decay=0.05) 35 | 36 | if criterion: 37 | self.criterion = criterion 38 | else: 39 | if multitask_learning: 40 | self.criterion = [] 41 | 42 | for i in range(self.num_classes): 43 | class_weights = torch.FloatTensor([1, self.num_classes]).to(device) 44 | self.criterion.append(nn.CrossEntropyLoss(weight=class_weights)) 45 | 46 | else: 47 | self.criterion = nn.CrossEntropyLoss() 48 | 49 | if scheduler: 50 | self.scheduler = scheduler 51 | elif self.warmup_lr: 52 | warmup_lr_multiplier = kwargs["warmup_lr_multiplier"] if "warmup_lr_multiplier" in kwargs else 1 53 | warmup_lr_total_epoch = kwargs["warmup_lr_total_epoch"] if "warmup_lr_total_epoch" in kwargs else 5 54 | scheduler_lr = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=patience, gamma=0.1) 55 | # scheduler_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=patience) 56 | self.scheduler = GradualWarmupScheduler(self.optimizer, multiplier=warmup_lr_multiplier, total_epoch=warmup_lr_total_epoch, after_scheduler=scheduler_lr) 57 | else: 58 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=patience) 59 | 60 | 61 | def computes_accuracy(self, outputs, targets, k=1): 62 | _, preds = outputs.topk(k, 1, True, True) 63 | preds = preds.t() 64 | correct = preds.eq(targets.view(1, -1).expand_as(preds)) 65 | correct_k = correct[:k].view(-1).float() 66 | return correct_k 67 | 68 | def get_preds(self, outputs, k=1): 69 | _, preds = outputs.topk(k, 1, True, True) 70 | preds = preds.t() 71 | return preds[0] 72 | 73 | def train(self, data_loader, topk_accuracy): 74 | self.model.train() 75 | return self._loop(data_loader, train_mode=True, topk_accuracy=topk_accuracy) 76 | 77 | def validate(self, data_loader, topk_accuracy): 78 | self.model.eval() 79 | return self._loop(data_loader, train_mode=False, topk_accuracy=topk_accuracy) 80 | 81 | def get_parameters(self): 82 | if self.verbose: 83 | print("Layers with params to learn:") 84 | params_to_update = [] 85 | for name, param in self.model.named_parameters(): 86 | if param.requires_grad == True: 87 | params_to_update.append(param) 88 | if self.verbose: 89 | print("\t",name) 90 | if self.verbose: 91 | print('\t', len(params_to_update), 'layers') 92 | return params_to_update 93 | 94 | def fit(self, dataloaders_dict, num_epochs=10, early_stopping=5, topk_accuracy=1, min_num_epoch=10, save_model=False): 95 | if self.verbose: 96 | print("-------") 97 | print("Starting training, on device:", self.device) 98 | 99 | time_fit_start = time.time() 100 | train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], [] 101 | early_stopping_counter = early_stopping 102 | 103 | best_epoch_info = { 104 | 'model_wts':copy.deepcopy(self.model.state_dict()), 105 | 'loss':1e10 106 | } 107 | 108 | for epoch in range(num_epochs): 109 | time_epoch_start = time.time() 110 | 111 | train_loss, train_acc, train_classification_f1 = self.train(dataloaders_dict['train'], topk_accuracy) 112 | val_loss, val_acc, val_classification_f1 = self.validate(dataloaders_dict['val'], topk_accuracy) 113 | 114 | train_losses.append(train_loss) 115 | test_losses.append(val_loss) 116 | train_accuracies.append(train_acc) 117 | test_accuracies.append(val_acc) 118 | 119 | improvement = False 120 | if val_loss < best_epoch_info['loss']: 121 | improvement = True 122 | best_epoch_info = { 123 | 'model_wts':copy.deepcopy(self.model.state_dict()), 124 | 'loss':val_loss, 125 | 'epoch':epoch, 126 | 
'metrics':{ 127 | 'train_loss':train_loss, 128 | 'val_loss':val_loss, 129 | 'train_acc':train_acc, 130 | 'val_acc':val_acc, 131 | 'train_classification_f1':train_classification_f1, 132 | 'val_classification_f1':val_classification_f1 133 | } 134 | } 135 | 136 | if early_stopping and epoch > min_num_epoch: 137 | if improvement: 138 | early_stopping_counter = early_stopping 139 | else: 140 | early_stopping_counter -= 1 141 | 142 | if early_stopping_counter <= 0: 143 | if self.verbose: 144 | print("Early Stop") 145 | break 146 | if val_loss < 0: 147 | print('val loss negative') 148 | break 149 | 150 | if self.verbose: 151 | print("Epoch {:2} in {:.0f}s || Train loss={:.3f}, acc={:.3f}, f1={:.3f} | Val loss={:.3f}, acc={:.3f}, f1={:.3f} | LR={:.1e} | best={} | improvement={}-{}".format( 152 | epoch+1, 153 | time.time() - time_epoch_start, 154 | train_loss, 155 | train_acc, 156 | train_classification_f1, 157 | val_loss, 158 | val_acc, 159 | val_classification_f1, 160 | self.optimizer.param_groups[0]['lr'], 161 | int(best_epoch_info['epoch'])+1, 162 | improvement, 163 | early_stopping_counter) 164 | ) 165 | 166 | if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 167 | self.scheduler.step(val_loss) 168 | # elif isinstance(self.scheduler, GradualWarmupScheduler): 169 | # self.scheduler.step(val_loss, epoch) 170 | else: 171 | self.scheduler.step() 172 | 173 | 174 | 175 | self.model.load_state_dict(best_epoch_info['model_wts']) 176 | 177 | if self.print_training_metric: 178 | print() 179 | time_elapsed = time.time() - time_fit_start 180 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 181 | 182 | print('Epoch with lowest val loss:', best_epoch_info['epoch']) 183 | for m in best_epoch_info['metrics']: 184 | print('{}: {:.5f}'.format(m, best_epoch_info['metrics'][m])) 185 | print() 186 | 187 | if save_model: 188 | torch.save(self.model.state_dict(), self.model_path) 189 | 190 | 191 | def _loop(self, data_loader, train_mode=True, topk_accuracy=1): 192 | running_loss = 0.0 193 | running_corrects = 0 194 | total_data_count = 0 195 | y_true = [] 196 | y_pred = [] 197 | 198 | for X, Y in data_loader: 199 | inputs = X.to(self.device) 200 | 201 | if self.multitask_learning: 202 | labels = [] 203 | for i in range(self.num_classes): 204 | labels.append( Y.T[i].long().to(self.device) ) 205 | else: 206 | labels = Y.long().to(self.device) 207 | 208 | if train_mode: 209 | self.optimizer.zero_grad() 210 | 211 | outputs = self.model(inputs) 212 | 213 | if self.multitask_learning: 214 | loss = 0 215 | 216 | for i in range(self.num_classes): 217 | output = outputs[:, i, :] 218 | label = labels[i] 219 | 220 | loss += self.criterion[i](output, label) * output.size(0) 221 | 222 | label = label.data.cpu().numpy() 223 | out = self.get_preds(output, topk_accuracy).cpu().numpy() 224 | 225 | index_label_1 = np.where(label == 1) 226 | 227 | if len(index_label_1) > 0: 228 | label_1s = label[index_label_1] 229 | output_1s = out[index_label_1] 230 | 231 | y_true.extend(label_1s) 232 | y_pred.extend(output_1s) 233 | 234 | else: 235 | loss = self.criterion(outputs, labels) * outputs.size(0) 236 | y_true.extend(labels.data.cpu().numpy()) 237 | y_pred.extend(self.get_preds(outputs, topk_accuracy).cpu().numpy()) 238 | 239 | running_loss += loss.item() * self.num_classes 240 | 241 | if train_mode: 242 | loss.backward() 243 | self.optimizer.step() 244 | 245 | epoch_loss = running_loss / len(y_true) 246 | epoch_acc = accuracy_score(y_true, y_pred) 247 | 
classification_f1 = 0 248 | if self.multitask_learning: 249 | classification_f1 = np.round(f1_score(y_true, y_pred), 3) 250 | 251 | return epoch_loss, np.round(epoch_acc.item(), 3), classification_f1 252 | 253 | 254 | 255 | 256 | # https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py 257 | 258 | from torch.optim.lr_scheduler import _LRScheduler 259 | from torch.optim.lr_scheduler import ReduceLROnPlateau 260 | 261 | 262 | class GradualWarmupScheduler(_LRScheduler): 263 | """ Gradually warm up (increase) the learning rate in the optimizer. 264 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 265 | Args: 266 | optimizer (Optimizer): Wrapped optimizer. 267 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, lr starts from 0 and ends up at base_lr. 268 | total_epoch: the target learning rate is reached gradually at total_epoch 269 | after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau) 270 | """ 271 | 272 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 273 | self.multiplier = multiplier 274 | if self.multiplier < 1.: 275 | raise ValueError('multiplier should be greater than or equal to 1.') 276 | self.total_epoch = total_epoch 277 | self.after_scheduler = after_scheduler 278 | self.finished = False 279 | super(GradualWarmupScheduler, self).__init__(optimizer) 280 | 281 | def get_lr(self): 282 | if self.last_epoch > self.total_epoch: 283 | if self.after_scheduler: 284 | if not self.finished: 285 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 286 | self.finished = True 287 | return self.after_scheduler.get_last_lr() 288 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 289 | 290 | if self.multiplier == 1.0: 291 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 292 | else: 293 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 294 | 295 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 296 | if epoch is None: 297 | epoch = self.last_epoch + 1 298 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 299 | if self.last_epoch <= self.total_epoch: 300 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
for base_lr in self.base_lrs] 301 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 302 | param_group['lr'] = lr 303 | else: 304 | if epoch is None: 305 | self.after_scheduler.step(metrics, None) 306 | else: 307 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 308 | 309 | def step(self, epoch=None, metrics=None): 310 | if type(self.after_scheduler) != ReduceLROnPlateau: 311 | if self.finished and self.after_scheduler: 312 | if epoch is None: 313 | self.after_scheduler.step(None) 314 | else: 315 | self.after_scheduler.step(epoch - self.total_epoch) 316 | self._last_lr = self.after_scheduler.get_last_lr() 317 | else: 318 | return super(GradualWarmupScheduler, self).step(epoch) 319 | else: 320 | self.step_ReduceLROnPlateau(metrics, epoch) -------------------------------------------------------------------------------- /torchsignal/transform/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinglescode/torchsignal/6172bc2b18eeafa9464cfba678e9c02ea4ed5e2a/torchsignal/transform/__init__.py -------------------------------------------------------------------------------- /torchsignal/transform/fft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.fft import fft 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def fast_fourier_transform(signal, sample_rate, plot=False, plot_xlim=[0, 80], plot_ylim=None, plot_label=''): 7 | r""" 8 | Use Fourier transforms to find the frequency components of a signal buried in noise. 9 | Reference: https://www.mathworks.com/help/matlab/ref/fft.html 10 | Args: 11 | signal : ndarray, shape (time,) 12 | Single input signal in time domain 13 | sample_rate: int 14 | Sampling frequency 15 | plot : boolean, default: False 16 | To plot the single-sided amplitude spectrum 17 | plot_xlim : array of shape [lower, upper], default: [0,80] 18 | Set a limit on the X-axis between lower and upper bound 19 | plot_label : string 20 | A text label for this signal in the plot 21 | Returns: 22 | P1 : ndarray, shape ((signal_length/2+1),) 23 | frequency domain 24 | Example: 25 | Fs = 1000 # Sampling frequency 26 | L = 4000 # Length of signal 27 | t = np.arange(0, (L/(Fs)), step=1.0/(Fs)) 28 | S = 0.7*np.sin(2*np.pi*10*t) + np.sin(2*np.pi*12*t) # Signal 29 | P1 = fast_fourier_transform(S, sample_rate=Fs, plot=True, plot_xlim=[0,20]) 30 | Dependencies: 31 | np : numpy package 32 | plt : matplotlib.pyplot 33 | fft : scipy.fft.fft 34 | """ 35 | 36 | signal_length = signal.shape[0] 37 | 38 | if signal_length % 2 != 0: 39 | signal_length = signal_length+1 40 | 41 | y = fft(signal) 42 | p2 = np.abs(y/signal_length) 43 | p1 = p2[0:round(signal_length/2+1)] 44 | p1[1:-1] = 2*p1[1:-1] 45 | 46 | if plot: 47 | # TODO change to this chart, https://www.oreilly.com/library/view/elegant-scipy/9781491922927/ch04.html 48 | f = sample_rate*np.arange(0, (signal_length/2)+1)/signal_length 49 | plt.plot(f, p1, label=plot_label) 50 | plt.xlim(plot_xlim) 51 | if plot_ylim is not None: 52 | plt.ylim(plot_ylim) 53 | 54 | return p1 55 | -------------------------------------------------------------------------------- /torchsignal/transform/segment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def segment_signal(signal, window_len, shift_len, sample_rate, add_segment_axis=True, verbose=False): 5 | r""" 6 | Divide a signal in the time domain into segments of length
`window_len`. 7 | 8 | Args: 9 | signal : ndarray, shape (trial,channel,time) 10 | Input signal by trials in time domain 11 | window_len : int 12 | Window/segment length (in seconds) 13 | shift_len : int 14 | Shift of the window (in time samples); indirectly specifies the overlap between consecutive windows 15 | sample_rate : int 16 | Sampling frequency 17 | add_segment_axis : boolean, default: True 18 | If True, segmented shape is (trial,channel,#segments,time), otherwise (trial,channel,time) 19 | verbose : boolean, default: False 20 | Print details 21 | Returns: 22 | segmented : ndarray, shape (trial,channel,#segments,time) or (trial,channel,time) 23 | Segmented matrix 24 | Example: 25 | Fs = 1000 # Sampling frequency 26 | num_trials = 100 27 | num_channels = 9 28 | num_timesamples = 4000 29 | S = np.zeros((num_trials,num_channels,num_timesamples)) # Signal 30 | 31 | window_len = 1 32 | shift_len = 1000 # in time samples (a 1-second shift at Fs=1000) 33 | segmented = segment_signal(S, window_len, shift_len, Fs, True) 34 | Dependencies: 35 | np : numpy package 36 | buffer : function 37 | """ 38 | 39 | assert len(signal.shape) == 3, "signal shape must be (trial,channel,time)" 40 | 41 | duration = int(window_len*sample_rate) 42 | # overlap between consecutive windows, in time samples 43 | data_overlap = (window_len*sample_rate-shift_len) 44 | 45 | num_segments = int( 46 | np.ceil((signal.shape[2]-data_overlap)/(duration-data_overlap))) 47 | 48 | if add_segment_axis: # return (trial,channel,segments,time) 49 | segmented = np.zeros( 50 | (signal.shape[0], signal.shape[1], num_segments, duration)) 51 | for trial in range(0, signal.shape[0]): 52 | for channel in range(0, signal.shape[1]): 53 | segmented[trial, channel, :, :] = buffer( 54 | signal[trial, channel], duration, data_overlap, num_segments) 55 | else: # return (trial,channel,time) 56 | segmented = np.zeros( 57 | (signal.shape[0]*num_segments, signal.shape[1], duration)) 58 | for trial in range(0, signal.shape[0]): 59 | for channel in range(0, signal.shape[1]): 60 | signal_buffer = buffer( 61 | signal[trial, channel], duration, data_overlap, num_segments) 62 | for segment in range(0, signal_buffer.shape[0]): 63 | index = (trial*num_segments)+segment 64 | segmented[index, channel, :] = signal_buffer[segment] 65 | 66 | if verbose: 67 | print('Duration', duration) 68 | print('Overlap', data_overlap) 69 | print('#segments', num_segments) 70 | print('Shape from', signal[0, 0].shape, 'to', buffer( 71 | signal[0, 0], duration, data_overlap, num_segments).shape) 72 | print('Input: Signal shape', signal.shape) 73 | print('Output: Segmented signal shape', segmented.shape) 74 | 75 | return segmented 76 | 77 | 78 | def buffer(signal, duration, data_overlap, number_segments, verbose=False): 79 | r""" 80 | Divide a single signal in the time domain into segments of length `duration`.
81 | 82 | Args: 83 | signal : ndarray, shape (time,) 84 | Single input signal in time domain 85 | duration : int 86 | Window/segment length in time samples 87 | data_overlap : int 88 | Segment length that is overlapped in time samples 89 | number_segments : int 90 | Number of segments 91 | verbose : boolean, default: False 92 | Print details 93 | Returns: 94 | segmented : ndarray, shape (#segments,time) 95 | Segmented signals 96 | Example: 97 | Fs = 1000 # Sampling frequency 98 | L = 4000 # Length of signal 99 | t = np.arange(0, (L/(Fs)), step=1.0/(Fs)) 100 | S = 0.7*np.sin(2*np.pi*10*t) + np.sin(2*np.pi*12*t) # Signal 101 | 102 | duration = 2000 # in time samples 103 | data_overlap = 1000 # in time samples 104 | segmented = buffer(S, duration, data_overlap, number_segments=3) 105 | Dependencies: 106 | np : numpy package 107 | """ 108 | 109 | temp_buf = [signal[i:i+duration] 110 | for i in range(0, len(signal), (duration-int(data_overlap)))] 111 | temp_buf[number_segments-1] = np.pad( 112 | temp_buf[number_segments-1], 113 | (0, duration-temp_buf[number_segments-1].shape[0]), 114 | 'constant') 115 | segmented = np.vstack(temp_buf[0:number_segments]) 116 | if verbose: 117 | print('Input: Signal shape', signal.shape) 118 | print('Output: Segmented signal shape', segmented.shape) 119 | 120 | return segmented 121 | -------------------------------------------------------------------------------- /torchsignal/transform/spectrogram.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | import torch 3 | from torch import Tensor 4 | 5 | 6 | class Spectrogram(torch.nn.Module): 7 | r"""Create a spectrogram from an audio signal. 8 | Code taken from TORCHAUDIO.TRANSFORMS[https://pytorch.org/audio/stable/transforms.html] 9 | TORCH.STFT [https://pytorch.org/docs/stable/generated/torch.stft.html] only accepts either a 1-D time sequence or a 2-D batch of time sequences as input. 10 | 11 | Parameters: 12 | waveform (Tensor) – Tensor of audio of dimension (…, time). 13 | 14 | Returns: 15 | Dimension (…, freq, time), where freq is ``n_fft // 2 + 1``, ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). 16 | 17 | Return type: 18 | Tensor 19 | 20 | Args: 21 | n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``) 22 | win_length (int or None, optional): Window size. (Default: ``n_fft``) 23 | hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) 24 | pad (int, optional): Two sided padding of signal. (Default: ``0``) 25 | window_fn (Callable[..., Tensor], optional): A function to create a window tensor 26 | that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) 27 | power (float or None, optional): Exponent for the magnitude spectrogram, 28 | (must be > 0) e.g., 1 for energy, 2 for power, etc. 29 | If None, then the complex spectrum is returned instead. (Default: ``2``) 30 | normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) 31 | wkwargs (dict or None, optional): Arguments for window function.
(Default: ``None``) 32 | """ 33 | __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] 34 | 35 | def __init__(self, 36 | n_fft: int = 400, 37 | win_length: Optional[int] = None, 38 | hop_length: Optional[int] = None, 39 | pad: int = 0, 40 | window_fn: Callable[..., Tensor] = torch.hann_window, 41 | power: Optional[float] = 2., 42 | normalized: bool = False, 43 | wkwargs: Optional[dict] = None) -> None: 44 | super(Spectrogram, self).__init__() 45 | self.n_fft = n_fft 46 | # number of FFT bins. the returned STFT result will have n_fft // 2 + 1 47 | # number of frequecies due to onesided=True in torch.stft 48 | self.win_length = win_length if win_length is not None else n_fft 49 | self.hop_length = hop_length if hop_length is not None else self.win_length // 2 50 | window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) 51 | self.register_buffer('window', window) 52 | self.pad = pad 53 | self.power = power 54 | self.normalized = normalized 55 | 56 | def forward(self, waveform: Tensor) -> Tensor: 57 | r""" 58 | Args: 59 | waveform (Tensor): Tensor of audio of dimension (..., time). 60 | 61 | Returns: 62 | Tensor: Dimension (..., freq, time), where freq is 63 | ``n_fft // 2 + 1`` where ``n_fft`` is the number of 64 | Fourier bins, and time is the number of window hops (n_frame). 65 | """ 66 | return self.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length, 67 | self.win_length, self.power, self.normalized) 68 | 69 | def spectrogram( 70 | self, 71 | waveform: Tensor, 72 | pad: int, 73 | window: Tensor, 74 | n_fft: int, 75 | hop_length: int, 76 | win_length: int, 77 | power: Optional[float], 78 | normalized: bool 79 | ) -> Tensor: 80 | r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. 81 | The spectrogram can be either magnitude-only or complex. 82 | Args: 83 | waveform (Tensor): Tensor of audio of dimension (..., time) 84 | pad (int): Two sided padding of signal 85 | window (Tensor): Window tensor that is applied/multiplied to each frame/window 86 | n_fft (int): Size of FFT 87 | hop_length (int): Length of hop between STFT windows 88 | win_length (int): Window size 89 | power (float or None): Exponent for the magnitude spectrogram, 90 | (must be > 0) e.g., 1 for energy, 2 for power, etc. 91 | If None, then the complex spectrum is returned instead. 92 | normalized (bool): Whether to normalize by magnitude after stft 93 | Returns: 94 | Tensor: Dimension (..., freq, time), freq is 95 | ``n_fft // 2 + 1`` and ``n_fft`` is the number of 96 | Fourier bins, and time is the number of window hops (n_frame). 
97 | """ 98 | 99 | if pad > 0: 100 | # TODO add "with torch.no_grad():" back when JIT supports it 101 | waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") 102 | 103 | # pack batch 104 | shape = waveform.size() 105 | waveform = waveform.reshape(-1, shape[-1]) 106 | 107 | # default values are consistent with librosa.core.spectrum._spectrogram 108 | spec_f = torch.stft( 109 | waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True 110 | ) 111 | 112 | # unpack batch 113 | spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:]) 114 | 115 | if normalized: 116 | spec_f /= window.pow(2.).sum().sqrt() 117 | if power is not None: 118 | spec_f = self.complex_norm(spec_f, power=power) 119 | 120 | return spec_f 121 | 122 | def complex_norm( 123 | self, 124 | complex_tensor: Tensor, 125 | power: float = 1.0 126 | ) -> Tensor: 127 | r"""Compute the norm of complex tensor input. 128 | Args: 129 | complex_tensor (Tensor): Tensor shape of `(..., complex=2)` 130 | power (float): Power of the norm. (Default: `1.0`). 131 | Returns: 132 | Tensor: Power of the normed input tensor. Shape of `(..., )` 133 | """ 134 | 135 | # Replace by torch.norm once issue is fixed 136 | # https://github.com/pytorch/pytorch/issues/34279 137 | return complex_tensor.pow(2.).sum(-1).pow(0.5 * power) --------------------------------------------------------------------------------
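
Usage sketch for the `Spectrogram` transform above — a minimal, hypothetical example, not code from the repository. The input shape and `n_fft` value are assumptions; note that on recent PyTorch releases `torch.stft` may additionally require `return_complex`, which the `spectrogram()` call above does not set.

```
# Hypothetical example for the Spectrogram transform; shapes and n_fft are illustrative.
import torch
from torchsignal.transform.spectrogram import Spectrogram

spec = Spectrogram(n_fft=128)            # 128 // 2 + 1 = 65 frequency bins
waveform = torch.randn(2, 9, 1000)       # fake (batch, channel, time) EEG-like input
out = spec(waveform)
print(out.shape)                         # e.g. torch.Size([2, 9, 65, 16]); frame count depends on hop_length
```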
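
Usage sketch combining `segment_signal` and `fast_fourier_transform` from the transform module — again a hypothetical example; the synthetic sine data, sampling rate, and window/shift values are illustrative assumptions. `shift_len` is given in time samples here, so the 1-second windows do not overlap.

```
# Hypothetical end-to-end example: segment synthetic trials, then inspect one segment's spectrum.
import numpy as np
from torchsignal.transform.segment import segment_signal
from torchsignal.transform.fft import fast_fourier_transform

sample_rate = 1000                                   # Hz (assumed)
t = np.arange(0, 4, step=1.0 / sample_rate)          # 4 seconds of data
sine = 0.7 * np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 12 * t)
signal = np.tile(sine, (5, 9, 1))                    # fake (trial, channel, time) = (5, 9, 4000)

# 1-second windows shifted by 1000 samples -> 4 non-overlapping segments per trial/channel
segmented = segment_signal(signal, window_len=1, shift_len=1000, sample_rate=sample_rate)
print(segmented.shape)                               # (5, 9, 4, 1000)

# Single-sided amplitude spectrum of one segment; peaks appear near 10 Hz and 12 Hz
spectrum = fast_fourier_transform(segmented[0, 0, 0], sample_rate=sample_rate)
print(spectrum.shape)                                # (501,)
```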
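
Training sketch for `Multitask_Trainer` — a hypothetical example run against the package's own pinned dependencies. The stand-in network, synthetic labels, and all shapes/hyperparameters are assumptions for illustration; the trainer only needs a model that returns `(batch, num_classes, 2)` logits when `multitask_learning=True`.

```
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchsignal.trainer.multitask import Multitask_Trainer

class TinyMultitaskNet(nn.Module):
    # Hypothetical stand-in model (not part of torchsignal): maps (batch, channel, time)
    # to (batch, num_tasks, 2), the output layout Multitask_Trainer expects.
    def __init__(self, num_channels=9, num_samples=250, num_tasks=4):
        super().__init__()
        self.num_tasks = num_tasks
        self.head = nn.Linear(num_channels * num_samples, num_tasks * 2)

    def forward(self, x):
        return self.head(x.flatten(1)).view(-1, self.num_tasks, 2)

# Synthetic dataset: X is (trials, channels, time), Y holds one binary label per task
X = torch.randn(64, 9, 250)
Y = torch.nn.functional.one_hot(torch.randint(0, 4, (64,)), num_classes=4).float()
dataset = TensorDataset(X, Y)
loaders = {
    'train': DataLoader(dataset, batch_size=16, shuffle=True),
    'val': DataLoader(dataset, batch_size=16),
}

trainer = Multitask_Trainer(TinyMultitaskNet(), model_name='tiny_demo', device='cpu',
                            num_classes=4, multitask_learning=True, verbose=True)
trainer.fit(loaders, num_epochs=2, topk_accuracy=1)
```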