├── .gitignore ├── LICENSE ├── README.md ├── examples ├── convca_benchmark.py ├── convca_jfpm.py ├── deepconvnet_jfpm.py ├── eegnet_benchmark.py ├── eegnet_beta.py ├── eegnet_jfpm.py ├── fbcca.py ├── load_benchmark_data.py ├── load_benchmark_dataset.py ├── load_beta_data.py ├── load_jfpm_data.py ├── multitask_benchmark.py ├── multitask_beta.py ├── multitask_jfpm.py ├── tcnn_benchmark.py ├── trca_benchmark.py ├── trca_beta.py ├── trca_jfpm.py └── trca_sample.py ├── requirements.txt ├── setup.py ├── splearn ├── __init__.py ├── classes │ └── classifier.py ├── cross_decomposition │ ├── __init__.py │ ├── cca.py │ ├── fbcca.py │ ├── filterbank.py │ ├── reference_frequencies.py │ └── trca.py ├── cross_validate │ └── leave_one_out.py ├── data │ ├── __init__.py │ ├── benchmark.py │ ├── beta.py │ ├── generate.py │ ├── jfpm.py │ ├── multiple_subjects.py │ ├── openbmi.py │ ├── pytorch_dataset.py │ ├── sample │ │ └── ssvep.mat │ ├── sample_ssvep.py │ └── utils.py ├── filter │ ├── __init__.py │ ├── butterworth.py │ ├── cca_spatial_filtering.py │ ├── channels.py │ └── notch.py ├── fourier.py ├── nn │ ├── base │ │ ├── __init__.py │ │ ├── classifier.py │ │ └── lightning.py │ ├── loss.py │ ├── models │ │ ├── ConvCA │ │ │ ├── ConvCA.py │ │ │ ├── ConvCaLighting.py │ │ │ └── __init__.py │ │ ├── DeepConvNet │ │ │ ├── DeepConvNet.py │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── EEGNet │ │ │ ├── CompactEEGNet.py │ │ │ └── __init__.py │ │ ├── Multitask │ │ │ ├── Multitask.py │ │ │ ├── MultitaskClassifier.py │ │ │ └── __init__.py │ │ ├── TimeDomainBasedCNN │ │ │ ├── TimeDomainBasedCNN.py │ │ │ └── __init__.py │ │ └── __init__.py │ ├── modules │ │ ├── conv1d.py │ │ ├── conv2d.py │ │ ├── functional.py │ │ ├── positional_encoding.py │ │ ├── relative_multi_head_attention.py │ │ └── residual_connection_module.py │ ├── optimization.py │ └── utils.py └── utils │ ├── __init__.py │ ├── config.py │ └── logger.py └── tutorials ├── Butterworth Filter.ipynb ├── Canonical Correlation Analysis.ipynb 
├── Denoising with Gaussian-smooth filter.ipynb ├── Denoising with mean-smooth filter.ipynb ├── Fourier Transform.ipynb ├── README.md ├── Signal composition - time, sampling rate and frequency.ipynb └── Task-Related Component Analysis.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # System Files 132 | .DS_Store 133 | Thumbs.db 134 | 135 | _devt/ 136 | tensorboard_logs/ 137 | run_logs/ 138 | experiments/ 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Python Signal Processing 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python Signal Processing [![Awesome](https://awesome.re/badge.svg)](https://awesome.re) 2 | 3 | This repository contains tutorials on understanding and applying signal processing using NumPy and PyTorch. 4 | 5 | **splearn** is a package for signal processing and machine learning with Python. It is built on top of [NumPy](https://numpy.org) and [SciPy](https://www.scipy.org), to provide easy to use functions from common signal processing tasks to machine learning. 
6 | 7 | ## Contents 8 | 9 | - [Tutorials](#tutorials) 10 | - [Getting Started](#getting-started) 11 | - [Datasets](#datasets) 12 | - [Methods](#methods) 13 | 14 | --- 15 | 16 | ## Tutorials 17 | 18 | Signal processing can be daunting; we aim to bridge the gap for anyone who is new to signal processing. Check out the [tutorials](https://github.com/jinglescode/python-signal-processing/tree/main/tutorials) to get started with signal processing. 19 | 20 | ### 1. Signal composition (time, sampling rate and frequency) 21 | 22 | In order to begin the signal processing adventure, we need to understand what we are dealing with. In the first tutorial, we will uncover what a signal is and what it is made up of. We will look at how the sampling rate and frequency can affect a signal. We will also see what happens when we combine multiple signals of different frequencies. 23 | 24 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Signal%20composition%20-%20time%2C%20sampling%20rate%20and%20frequency.ipynb) 25 | 26 | ### 2. Fourier Transform 27 | 28 | Now we know what signals are made of, and we learned that combining multiple signals of various frequencies jumbles up all the frequencies. In this tutorial, we will learn about the Fourier Transform and how it can take a complex signal and decompose it into the frequencies that make it up. 29 | 30 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Fourier%20Transform.ipynb) 31 | 32 | ### 3. Denoising with mean-smooth filter 33 | 34 | We know that signals can be noisy, and this tutorial will focus on removing this noise. We will learn to apply the simplest denoising filter, the running mean filter. We will also see what edge effects are.
35 | 36 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Denoising%20with%20mean-smooth%20filter.ipynb) 37 | 38 | ### 4. Denoising with Gaussian-smooth filter 39 | 40 | Next, we will look at a slight adaptation of the mean-smooth filter, the Gaussian smoothing filter. It tends to produce a slightly smoother result than the mean-smooth filter. This does not mean that one is better than the other; it depends on the specific application. So, it is important to be aware of different filter types and how to use them. 41 | 42 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Denoising%20with%20Gaussian-smooth%20filter.ipynb) 43 | 44 | ### 5. Canonical correlation analysis 45 | 46 | Canonical correlation analysis (CCA) is applied to analyze the frequency components of a signal. In this tutorial, we use CCA for feature extraction and classification. 47 | 48 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Canonical%20Correlation%20Analysis.ipynb) 49 | 50 | ### 6. Task-related component analysis 51 | 52 | Task-related component analysis (TRCA) is a classification method originally developed for detecting steady-state visual evoked potentials (SSVEPs). 53 | 54 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Task-Related%20Component%20Analysis.ipynb) 55 | 56 | --- 57 | 58 | ## Getting Started 59 | 60 | ### Installation 61 | 62 | Currently, this has not been released.
Clone the repository with `git clone`, then install the dependencies from inside it: 63 | 64 | ``` 65 | git clone https://github.com/jinglescode/python-signal-processing.git 66 | cd python-signal-processing && pip install -r requirements.txt 67 | ``` 68 | 69 | For the full list of dependencies, see [requirements.txt](https://github.com/jinglescode/python-signal-processing/tree/main/requirements.txt). 70 | 71 | ### Usage 72 | 73 | Let's generate a 2D signal sampled at 100 Hz, then design and apply a 4th-order bandpass Butterworth filter with a passband between 5 Hz and 20 Hz. 74 | 75 | ```python 76 | from splearn.data.generate import generate_signal 77 | from splearn.filter.butterworth import butter_bandpass 78 | 79 | signal_2d = generate_signal( 80 | length_seconds=4, 81 | sampling_rate=100, 82 | frequencies=[[4,7,11,17,40, 50],[1, 3]], 83 | plot=True 84 | ) 85 | 86 | signal_2d_filtered = butter_bandpass( 87 | signal=signal_2d, 88 | lowcut=5, 89 | highcut=20, 90 | sampling_rate=100, 91 | type='sos', 92 | order=4, 93 | plot=True, 94 | plot_xlim=[3,20] 95 | ) 96 | ``` 97 | 98 | See [examples](https://github.com/jinglescode/python-signal-processing/tree/main/examples) for more examples. 99 | 100 | --- 101 | 102 | ## Datasets 103 | 104 | Some dataset extraction code is provided; see [examples](https://github.com/jinglescode/python-signal-processing/tree/main/examples) for how to use it.
105 | 106 | - [BETA: A Large Benchmark Database Toward SSVEP-BCI Application](https://www.frontiersin.org/articles/10.3389/fnins.2020.00627/full) 107 | - [A Benchmark Dataset for SSVEP-Based Brain–Computer Interfaces](https://ieeexplore.ieee.org/document/7740878) 108 | - [A Comparison Study of Canonical Correlation Analysis Based Methods for Detecting Steady-State Visual Evoked Potentials](http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0140703) 109 | - [EEG dataset and OpenBMI toolbox for three BCI paradigms: an investigation into BCI illiteracy.](https://academic.oup.com/gigascience/article/8/5/giz002/5304369) 110 | 111 | #### Disclaimer on Datasets 112 | 113 | We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license. 114 | 115 | If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community! 116 | 117 | ## Methods 118 | 119 | A few methods have been included as they have been tested (within my limited time). They are by no means recommended or exhaustive. 
120 | 121 | - [Compact-CNN](https://arxiv.org/pdf/1803.04566.pdf) 122 | - [Canonical Correlation Analysis (CCA)](http://en.wikipedia.org/wiki/Canonical_correlation) 123 | - [Task-Related Component Analysis (TRCA)](https://ieeexplore.ieee.org/document/7904641) 124 | - [Multi-Task SSVEP](https://jinglescode.github.io/ssvep-multi-task-learning/) 125 | - [Convolutional correlation analysis for enhancing the performance of SSVEP-based brain-computer interface](https://ieeexplore.ieee.org/abstract/document/9261605/) 126 | - [Time-domain-based CNN method (tCNN)](https://ieeexplore.ieee.org/abstract/document/9632600) 127 | 128 | If you wish to include your model, you are welcome to do so via pull request. Do check out how other models are implemented for reference. PyTorch models and sklearn/numpy methods only. 129 | -------------------------------------------------------------------------------- /examples/convca_jfpm.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, JFPM, PyTorchDataset2Views 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import 
notch_filter 31 | from splearn.nn.models import ConvCA, ConvCaLighting 32 | from splearn.cross_decomposition.reference_frequencies import generate_reference_signals 33 | 34 | config = { 35 | "experiment_name": "convca_jfpm_nokfold", 36 | "data": { 37 | "load_subject_ids": np.arange(1,11), 38 | "root": "../data/jfpm", 39 | "duration": 1, 40 | }, 41 | "model": { 42 | "optimizer": "adamw", 43 | "scheduler": "cosine_with_warmup", 44 | }, 45 | "training": { 46 | "num_epochs": 100, 47 | "num_warmup_epochs": 20, 48 | "learning_rate": 0.03, 49 | "gpus": [0], 50 | "batchsize": 256, 51 | }, 52 | "testing": { 53 | "test_subject_ids": np.arange(1,11), 54 | "kfolds": np.arange(0,3), 55 | }, 56 | "seed": 1234 57 | } 58 | main_logger = Logger(filename_postfix=config["experiment_name"]) 59 | main_logger.write_to_log("Config") 60 | main_logger.write_to_log(config) 61 | config = Config(config) 62 | 63 | seed_everything(config.seed) 64 | 65 | # define custom preprocessing steps 66 | def func_preprocessing(data): 67 | data_x = data.data 68 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 69 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 70 | start_t = 35 71 | end_t = start_t + (config.data.duration * data.sampling_rate) 72 | data_x = data_x[:,:,:,start_t:end_t] 73 | data.set_data(data_x) 74 | 75 | 76 | # prepare data loader 77 | def leave_one_subject_out(data, **kwargs): 78 | test_subject_id = kwargs["test_subject_id"] if "test_subject_id" in kwargs else 1 79 | kfold_k = kwargs["kfold_k"] if "kfold_k" in kwargs else 0 80 | kfold_split = kwargs["kfold_split"] if "kfold_split" in kwargs else 3 81 | 82 | num_subjects = data.data.shape[0] 83 | num_trials = data.data.shape[1] 84 | num_channel = data.data.shape[2] 85 | size = data.data.shape[3] 86 | sampling_rate = data.sampling_rate 87 | target_frequencies = data.stimulus_frequencies 88 | 89 | ref = generate_reference_signals(target_frequencies, 
size, sampling_rate, num_harmonics=1) 90 | ref = ref[:, 0, :] 91 | ref = np.expand_dims(ref, axis=1) 92 | ref = np.repeat(ref, num_channel, axis=1) 93 | ref = np.transpose(ref, (1,2,0)) 94 | ref = np.expand_dims(ref, axis=0) 95 | 96 | # get test data 97 | test_sub_idx = np.where(data.subject_ids == test_subject_id)[0][0] 98 | selected_subject_data = data.data[test_sub_idx] 99 | selected_subject_targets = data.targets[test_sub_idx] 100 | selected_subject_ref = np.repeat(ref, selected_subject_data.shape[0], axis=0) 101 | test_dataset = PyTorchDataset2Views(selected_subject_data, selected_subject_ref, selected_subject_targets) 102 | 103 | # get train val data 104 | indices = np.arange(data.data.shape[0]) 105 | train_val_data = data.data[indices!=test_sub_idx, :, :, :] 106 | train_val_data = train_val_data.reshape((train_val_data.shape[0]*train_val_data.shape[1], train_val_data.shape[2], train_val_data.shape[3])) 107 | train_val_targets = data.targets[indices!=test_sub_idx, :] 108 | train_val_targets = train_val_targets.reshape((train_val_targets.shape[0]*train_val_targets.shape[1])) 109 | 110 | # train val split 111 | # (X_train, y_train), (X_val, y_val) = data.dataset_split_stratified(train_val_data, train_val_targets, k=kfold_k, n_splits=kfold_split) 112 | # train_ref = np.repeat(ref, X_train.shape[0], axis=0) 113 | # val_ref = np.repeat(ref, X_val.shape[0], axis=0) 114 | # train_dataset = PyTorchDataset2Views(X_train, train_ref, y_train) 115 | # val_dataset = PyTorchDataset2Views(X_val, val_ref, y_val) 116 | # return train_dataset, val_dataset, test_dataset 117 | 118 | # no kfold 119 | X_train = train_val_data 120 | train_ref = np.repeat(ref, X_train.shape[0], axis=0) 121 | y_train = train_val_targets 122 | train_dataset = PyTorchDataset2Views(X_train, train_ref, y_train) 123 | return train_dataset, test_dataset 124 | 125 | 126 | # load data 127 | data = MultipleSubjects( 128 | dataset=JFPM, 129 | root=os.path.join(path,config.data.root), 130 | 
subject_ids=config.data.load_subject_ids, 131 | func_preprocessing=func_preprocessing, 132 | func_get_train_val_test_dataset=leave_one_subject_out, 133 | verbose=True, 134 | ) 135 | 136 | num_channel = data.data.shape[2] 137 | num_classes = data.stimulus_frequencies.shape[0] 138 | signal_length = data.data.shape[3] 139 | 140 | ##### test data 141 | 142 | test_subject_id = 1 143 | train_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id) 144 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 145 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 146 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 147 | 148 | # print() 149 | print("train_loader", train_loader.dataset.data_view1.shape, train_loader.dataset.data_view2.shape, train_loader.dataset.targets.shape) 150 | # print("val_loader", val_loader.dataset.data_view1.shape, val_loader.dataset.data_view2.shape, val_loader.dataset.targets.shape) 151 | print("test_loader", test_loader.dataset.data_view1.shape, test_loader.dataset.data_view2.shape, test_loader.dataset.targets.shape) 152 | 153 | ###### 154 | 155 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 156 | 157 | ## init data 158 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 159 | # train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 160 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 161 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 162 | # no kfold 163 | train_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id) 164 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 165 | 
test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 166 | 167 | ## init model 168 | base_model = ConvCA(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length) 169 | model = ConvCaLighting( 170 | optimizer=config.model.optimizer, 171 | scheduler=config.model.scheduler, 172 | optimizer_learning_rate=config.training.learning_rate, 173 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 174 | ) 175 | model.build_model(model=base_model) 176 | 177 | ## train 178 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 179 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 180 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 181 | 182 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 183 | # trainer.fit(model, train_loader, val_loader) 184 | trainer.fit(model, train_loader) 185 | 186 | ## test 187 | 188 | result = trainer.test(dataloaders=test_loader, verbose=False) 189 | test_acc = result[0]['test_acc_epoch'] 190 | 191 | return test_acc 192 | 193 | #### 194 | 195 | main_logger.write_to_log("Begin", break_line=True) 196 | 197 | test_results_acc = {} 198 | means = [] 199 | 200 | def k_fold_train_test_all_subjects(): 201 | 202 | for test_subject_id in config.testing.test_subject_ids: 203 | print() 204 | print("running test_subject_id:", test_subject_id) 205 | 206 | if test_subject_id not in test_results_acc: 207 | test_results_acc[test_subject_id] = [] 208 | 209 | # k-fold 210 | # for k in config.testing.kfolds: 211 | # test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 212 | # test_results_acc[test_subject_id].append(test_acc) 213 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 214 | # means.append(mean_acc) 215 | 216 | # one fold: 217 | mean_acc = 
train_test_subject_kfold(data, config, test_subject_id) 218 | means.append(mean_acc) 219 | 220 | this_result = { 221 | "test_subject_id": test_subject_id, 222 | "mean_acc": mean_acc, 223 | "acc": test_results_acc[test_subject_id], 224 | } 225 | print(this_result) 226 | main_logger.write_to_log(this_result) 227 | 228 | k_fold_train_test_all_subjects() 229 | 230 | mean_acc = np.mean(means) 231 | print() 232 | print("mean all", mean_acc) 233 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 234 | -------------------------------------------------------------------------------- /examples/deepconvnet_jfpm.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, JFPM 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.nn.models import DeepConvNet 32 | from splearn.nn.base import LightningModelClassifier 33 | 34 | config = { 35 | "experiment_name": "deepconvnet_jfpm_nokfold", 36 | "data": { 37 | "load_subject_ids": np.arange(1,11), 38 | "root": "../data/jfpm", 39 | "duration": 1, 40 | }, 41 | "model": 
{ 42 | "optimizer": "adamw", 43 | "scheduler": "cosine_with_warmup", 44 | }, 45 | "training": { 46 | "num_epochs": 100, 47 | "num_warmup_epochs": 20, 48 | "learning_rate": 0.03, 49 | "gpus": [0], 50 | "batchsize": 256, 51 | }, 52 | "testing": { 53 | "test_subject_ids": np.arange(1,11), 54 | "kfolds": np.arange(0,3), 55 | }, 56 | "seed": 1234 57 | } 58 | main_logger = Logger(filename_postfix=config["experiment_name"]) 59 | main_logger.write_to_log("Config") 60 | main_logger.write_to_log(config) 61 | config = Config(config) 62 | 63 | seed_everything(config.seed) 64 | 65 | # define custom preprocessing steps 66 | def func_preprocessing(data): 67 | data_x = data.data 68 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 69 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 70 | start_t = 35 71 | end_t = start_t + (config.data.duration * data.sampling_rate) 72 | data_x = data_x[:,:,:,start_t:end_t] 73 | data.set_data(data_x) 74 | 75 | # load data 76 | data = MultipleSubjects( 77 | dataset=JFPM, 78 | root=os.path.join(path,config.data.root), 79 | subject_ids=config.data.load_subject_ids, 80 | func_preprocessing=func_preprocessing, 81 | verbose=True, 82 | ) 83 | 84 | num_channel = data.data.shape[2] 85 | num_classes = data.stimulus_frequencies.shape[0] 86 | signal_length = data.data.shape[3] 87 | 88 | 89 | test_subject_id = 1 90 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 91 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 92 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 93 | print("train_loader", train_loader.dataset.data.shape, train_loader.dataset.targets.shape) 94 | print("test_loader", test_loader.dataset.data.shape, test_loader.dataset.targets.shape) 95 | 96 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 97 | 98 
| ## init data 99 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 100 | # train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 101 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 102 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 103 | # no kfold 104 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 105 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 106 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 107 | 108 | ## init model 109 | 110 | base_model = DeepConvNet(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length, pool_time_stride=1) 111 | 112 | model = LightningModelClassifier( 113 | optimizer=config.model.optimizer, 114 | scheduler=config.model.scheduler, 115 | optimizer_learning_rate=config.training.learning_rate, 116 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 117 | ) 118 | 119 | model.build_model(model=base_model) 120 | 121 | ## train 122 | 123 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 124 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 125 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 126 | 127 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 128 | # trainer.fit(model, train_loader, val_loader) 129 | trainer.fit(model, train_loader) 130 | 131 | ## test 132 | 133 | result = trainer.test(dataloaders=test_loader, verbose=False) 134 | test_acc = result[0]['test_acc_epoch'] 135 | 136 | return test_acc 137 | 138 | #### 139 | 140 | 
main_logger.write_to_log("Begin", break_line=True) 141 | 142 | test_results_acc = {} 143 | means = [] 144 | 145 | def k_fold_train_test_all_subjects(): 146 | 147 | for test_subject_id in config.testing.test_subject_ids: 148 | print() 149 | print("running test_subject_id:", test_subject_id) 150 | 151 | if test_subject_id not in test_results_acc: 152 | test_results_acc[test_subject_id] = [] 153 | 154 | # k-fold 155 | # for k in config.testing.kfolds: 156 | # test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 157 | # test_results_acc[test_subject_id].append(test_acc) 158 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 159 | # means.append(mean_acc) 160 | # one fold: 161 | mean_acc = train_test_subject_kfold(data, config, test_subject_id) 162 | means.append(mean_acc) 163 | 164 | this_result = { 165 | "test_subject_id": test_subject_id, 166 | "mean_acc": mean_acc, 167 | "acc": test_results_acc[test_subject_id], 168 | } 169 | print(this_result) 170 | main_logger.write_to_log(this_result) 171 | 172 | k_fold_train_test_all_subjects() 173 | 174 | mean_acc = np.mean(means) 175 | print() 176 | print("mean all", mean_acc) 177 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 178 | -------------------------------------------------------------------------------- /examples/eegnet_benchmark.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | 
pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, Benchmark 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.filter.channels import pick_channels 32 | from splearn.nn.models import CompactEEGNet 33 | from splearn.nn.base import LightningModelClassifier 34 | 35 | config = { 36 | "experiment_name": "eegnet_benchmark_nokfold", 37 | "data": { 38 | "load_subject_ids": np.arange(1,36), 39 | "root": "../data/hsssvep", 40 | "selected_channels": ["PZ", "PO5", "PO3", "POz", "PO4", "PO6", "O1", "Oz", "O2"], 41 | "duration": 1, 42 | }, 43 | "model": { 44 | "optimizer": "adamw", 45 | "scheduler": "cosine_with_warmup", 46 | }, 47 | "training": { 48 | "num_epochs": 100, 49 | "num_warmup_epochs": 20, 50 | "learning_rate": 0.03, 51 | "gpus": [0], 52 | "batchsize": 256, 53 | }, 54 | "testing": { 55 | "test_subject_ids": np.arange(1,36), 56 | "kfolds": np.arange(0,3), 57 | }, 58 | "seed": 1234 59 | } 60 | main_logger = Logger(filename_postfix=config["experiment_name"]) 61 | main_logger.write_to_log("Config") 62 | main_logger.write_to_log(config) 63 | config = Config(config) 64 | 65 | seed_everything(config.seed) 66 | 67 | # define custom preprocessing steps 68 | def func_preprocessing(data): 69 | data_x = data.data 70 | data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels) 71 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 72 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 73 | start_t = 35 74 | end_t = start_t + (config.data.duration * 
data.sampling_rate) 75 | data_x = data_x[:,:,:,start_t:end_t] 76 | data.set_data(data_x) 77 | 78 | # load data 79 | data = MultipleSubjects( 80 | dataset=Benchmark, 81 | root=os.path.join(path,config.data.root), 82 | subject_ids=config.data.load_subject_ids, 83 | func_preprocessing=func_preprocessing, 84 | verbose=True, 85 | ) 86 | 87 | num_channel = data.data.shape[2] 88 | num_classes = data.stimulus_frequencies.shape[0] 89 | signal_length = data.data.shape[3] 90 | 91 | 92 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 93 | 94 | ## init data 95 | # train_dataset, val_dataset, test_dataset = leave_one_subject_out(data, test_subject_id=test_subject_id, kfold_k=kfold_k) 96 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 97 | # train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 98 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 99 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 100 | # no kfold 101 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 102 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 103 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 104 | 105 | ## init model 106 | 107 | base_model = CompactEEGNet(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length) 108 | 109 | model = LightningModelClassifier( 110 | optimizer=config.model.optimizer, 111 | scheduler=config.model.scheduler, 112 | optimizer_learning_rate=config.training.learning_rate, 113 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 114 | ) 115 | 116 | model.build_model(model=base_model) 117 | 118 | ## train 119 | 120 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 121 | 
logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 122 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 123 | 124 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 125 | # trainer.fit(model, train_loader, val_loader) 126 | trainer.fit(model, train_loader) 127 | 128 | ## test 129 | 130 | result = trainer.test(dataloaders=test_loader, verbose=False) 131 | test_acc = result[0]['test_acc_epoch'] 132 | 133 | return test_acc 134 | 135 | #### 136 | 137 | main_logger.write_to_log("Begin", break_line=True) 138 | 139 | test_results_acc = {} 140 | means = [] 141 | 142 | def k_fold_train_test_all_subjects(): 143 | 144 | for test_subject_id in config.testing.test_subject_ids: 145 | print() 146 | print("running test_subject_id:", test_subject_id) 147 | 148 | if test_subject_id not in test_results_acc: 149 | test_results_acc[test_subject_id] = [] 150 | 151 | # k-fold 152 | # for k in config.testing.kfolds: 153 | # test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 154 | # test_results_acc[test_subject_id].append(test_acc) 155 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 156 | # means.append(mean_acc) 157 | # one fold: 158 | mean_acc = train_test_subject_kfold(data, config, test_subject_id) 159 | means.append(mean_acc) 160 | 161 | this_result = { 162 | "test_subject_id": test_subject_id, 163 | "mean_acc": mean_acc, 164 | "acc": test_results_acc[test_subject_id], 165 | } 166 | print(this_result) 167 | main_logger.write_to_log(this_result) 168 | 169 | k_fold_train_test_all_subjects() 170 | 171 | mean_acc = np.mean(means) 172 | print() 173 | print("mean all", mean_acc) 174 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 175 | -------------------------------------------------------------------------------- 
/examples/eegnet_beta.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, Beta 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.filter.channels import pick_channels 32 | from splearn.nn.models import CompactEEGNet 33 | from splearn.nn.base import LightningModelClassifier 34 | # Experiment settings; logged as a plain dict, then wrapped in Config for attribute access. 35 | config = { 36 | "experiment_name": "eegnet_beta_nokfold", 37 | "data": { 38 | "load_subject_ids": np.arange(1,71), 39 | "root": "../data/beta", 40 | "selected_channels": ["PZ","PO3","PO5","PO4","PO6","POZ","O1","OZ","O2"], 41 | "duration": 1,  # seconds; func_preprocessing keeps duration * data.sampling_rate samples 42 | }, 43 | "model": { 44 | "optimizer": "adamw", 45 | "scheduler": "cosine_with_warmup", 46 | }, 47 | "training": { 48 | "num_epochs": 100, 49 | "num_warmup_epochs": 20, 50 | "learning_rate": 0.03, 51 | "gpus": [0], 52 | "batchsize": 256, 53 | }, 54 | "testing": { 55 | "test_subject_ids": np.arange(1,71), 56 | "kfolds": np.arange(0,3), 57 | }, 58 | "seed": 1234 59 | } 60 | main_logger = 
Logger(filename_postfix=config["experiment_name"]) 61 | main_logger.write_to_log("Config") 62 | 
main_logger.write_to_log(config) 63 | config = Config(config) 64 | 65 | seed_everything(config.seed) 66 | 67 | # define custom preprocessing steps 68 | def func_preprocessing(data): 69 | data_x = data.data 70 | data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels) 71 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 72 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 73 | start_t = 35 74 | end_t = start_t + (config.data.duration * data.sampling_rate) 75 | data_x = data_x[:,:,:,start_t:end_t] 76 | data.set_data(data_x) 77 | 78 | # load data 79 | data = MultipleSubjects( 80 | dataset=Beta, 81 | root=os.path.join(path,config.data.root), 82 | subject_ids=config.data.load_subject_ids, 83 | func_preprocessing=func_preprocessing, 84 | verbose=True, 85 | ) 86 | 87 | num_channel = data.data.shape[2] 88 | num_classes = data.stimulus_frequencies.shape[0] 89 | signal_length = data.data.shape[3] 90 | 91 | 92 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 93 | 94 | ## init data 95 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 96 | # train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 97 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 98 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 99 | # no kfold 100 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 101 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 102 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 103 | 104 | ## init model 105 | 106 | base_model = CompactEEGNet(num_channel=num_channel, 
num_classes=num_classes, signal_length=signal_length) 107 | 108 | model = LightningModelClassifier( 109 | optimizer=config.model.optimizer, 110 | scheduler=config.model.scheduler, 111 | optimizer_learning_rate=config.training.learning_rate, 112 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 113 | ) 114 | 115 | model.build_model(model=base_model) 116 | 117 | ## train 118 | 119 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 120 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 121 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 122 | 123 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 124 | # trainer.fit(model, train_loader, val_loader) 125 | trainer.fit(model, train_loader) 126 | 127 | ## test 128 | 129 | result = trainer.test(dataloaders=test_loader, verbose=False) 130 | test_acc = result[0]['test_acc_epoch'] 131 | 132 | return test_acc 133 | 134 | #### 135 | 136 | main_logger.write_to_log("Begin", break_line=True) 137 | 138 | test_results_acc = {} 139 | means = [] 140 | 141 | def k_fold_train_test_all_subjects(): 142 | 143 | for test_subject_id in config.testing.test_subject_ids: 144 | print() 145 | print("running test_subject_id:", test_subject_id) 146 | 147 | if test_subject_id not in test_results_acc: 148 | test_results_acc[test_subject_id] = [] 149 | 150 | # k-fold 151 | # for k in config.testing.kfolds: 152 | # test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 153 | # test_results_acc[test_subject_id].append(test_acc) 154 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 155 | # means.append(mean_acc) 156 | # one fold: 157 | mean_acc = train_test_subject_kfold(data, config, test_subject_id) 158 | means.append(mean_acc) 159 | 160 | this_result = { 161 | "test_subject_id": test_subject_id, 162 | 
"mean_acc": mean_acc, 163 | "acc": test_results_acc[test_subject_id], 164 | } 165 | print(this_result) 166 | main_logger.write_to_log(this_result) 167 | 168 | k_fold_train_test_all_subjects() 169 | 170 | mean_acc = np.mean(means) 171 | print() 172 | print("mean all", mean_acc) 173 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 174 | -------------------------------------------------------------------------------- /examples/eegnet_jfpm.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, JFPM 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.nn.models import CompactEEGNet 32 | from splearn.nn.base import LightningModelClassifier 33 | 34 | config = { 35 | "experiment_name": "eegnet_jfpm_nokfold", 36 | "data": { 37 | "load_subject_ids": np.arange(1,11), 38 | "root": "../data/jfpm", 39 | "duration": 1, 40 | }, 41 | "model": { 42 | "optimizer": "adamw", 43 | "scheduler": "cosine_with_warmup", 44 | }, 45 | "training": { 46 | "num_epochs": 100, 47 | "num_warmup_epochs": 20, 48 | 
"learning_rate": 0.03, 49 | "gpus": [0], 50 | "batchsize": 256, 51 | }, 52 | "testing": { 53 | "test_subject_ids": np.arange(1,11), 54 | "kfolds": np.arange(0,3), 55 | }, 56 | "seed": 1234 57 | } 58 | main_logger = Logger(filename_postfix=config["experiment_name"]) 59 | main_logger.write_to_log("Config") 60 | main_logger.write_to_log(config) 61 | config = Config(config) 62 | 63 | seed_everything(config.seed) 64 | 65 | # define custom preprocessing steps 66 | def func_preprocessing(data): 67 | data_x = data.data 68 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 69 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 70 | start_t = 35 71 | end_t = start_t + (config.data.duration * data.sampling_rate) 72 | data_x = data_x[:,:,:,start_t:end_t] 73 | data.set_data(data_x) 74 | 75 | # load data 76 | data = MultipleSubjects( 77 | dataset=JFPM, 78 | root=os.path.join(path,config.data.root), 79 | subject_ids=config.data.load_subject_ids, 80 | func_preprocessing=func_preprocessing, 81 | verbose=True, 82 | ) 83 | 84 | num_channel = data.data.shape[2] 85 | num_classes = data.stimulus_frequencies.shape[0] 86 | signal_length = data.data.shape[3] 87 | 88 | 89 | test_subject_id = 1 90 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 91 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 92 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 93 | print("train_loader", train_loader.dataset.data.shape, train_loader.dataset.targets.shape) 94 | print("test_loader", test_loader.dataset.data.shape, test_loader.dataset.targets.shape) 95 | 96 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 97 | 98 | ## init data 99 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 100 | # 
train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 101 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 102 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 103 | # no kfold 104 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 105 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 106 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 107 | 108 | ## init model 109 | 110 | base_model = CompactEEGNet(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length) 111 | 112 | model = LightningModelClassifier( 113 | optimizer=config.model.optimizer, 114 | scheduler=config.model.scheduler, 115 | optimizer_learning_rate=config.training.learning_rate, 116 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 117 | ) 118 | 119 | model.build_model(model=base_model) 120 | 121 | ## train 122 | 123 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 124 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 125 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 126 | 127 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 128 | # trainer.fit(model, train_loader, val_loader) 129 | trainer.fit(model, train_loader) 130 | 131 | ## test 132 | 133 | result = trainer.test(dataloaders=test_loader, verbose=False) 134 | test_acc = result[0]['test_acc_epoch'] 135 | 136 | return test_acc 137 | 138 | #### 139 | 140 | main_logger.write_to_log("Begin", break_line=True) 141 | 142 | test_results_acc = {} 143 | means = [] 144 | 145 | def k_fold_train_test_all_subjects(): 146 | 147 | for test_subject_id in 
config.testing.test_subject_ids: 148 | print() 149 | print("running test_subject_id:", test_subject_id) 150 | 151 | if test_subject_id not in test_results_acc: 152 | test_results_acc[test_subject_id] = [] 153 | 154 | # k-fold 155 | # for k in config.testing.kfolds: 156 | # test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 157 | # test_results_acc[test_subject_id].append(test_acc) 158 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 159 | # means.append(mean_acc) 160 | # one fold: 161 | mean_acc = train_test_subject_kfold(data, config, test_subject_id) 162 | means.append(mean_acc) 163 | 164 | this_result = { 165 | "test_subject_id": test_subject_id, 166 | "mean_acc": mean_acc, 167 | "acc": test_results_acc[test_subject_id], 168 | } 169 | print(this_result) 170 | main_logger.write_to_log(this_result) 171 | 172 | k_fold_train_test_all_subjects() 173 | 174 | mean_acc = np.mean(means) 175 | print() 176 | print("mean all", mean_acc) 177 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 178 | -------------------------------------------------------------------------------- /examples/fbcca.py: -------------------------------------------------------------------------------- 1 | # for running locally (may remove this if your path is right) 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | path = cwd 6 | sys.path.append(path) 7 | 8 | # # imports 9 | from splearn.data import Benchmark 10 | from splearn.cross_decomposition.fbcca import fbcca, fbcca_realtime 11 | 12 | 13 | 14 | # # config 15 | # load_subject_id = 1 16 | # path_to_dataset = "../data/hsssvep" 17 | 18 | # # load data 19 | # subject_dataset = Benchmark(root=path_to_dataset, subject_id=load_subject_id) 20 | # print(subject_dataset.data.shape) 21 | 22 | 23 | # eeg = subject_dataset.data[:, :, 250:500] 24 | # fs = 250 25 | # list_freqs = subject_dataset.stimulus_frequencies 26 | # print("list_freqs", list_freqs) 27 | 28 | # # results = fbcca(eeg, 
list_freqs, fs, num_harms=3, num_fbs=5) 29 | # results = fbcca_realtime(eeg, list_freqs, fs, num_harms=3, num_fbs=5) 30 | # print(results) 31 | 32 | 33 | import numpy as np 34 | 35 | SAMPLE_RATE = 500 36 | t = np.linspace(0,1, num=SAMPLE_RATE) 37 | s = np.sin(2*np.pi*11*t) 38 | s = s[np.newaxis,:] 39 | ss = np.repeat(s, 32, axis=0) 40 | sss = ss[np.newaxis,:] 41 | sss = np.repeat(sss, 8, axis=0) 42 | list_freqs = np.arange(8.0,13.0+1,1) 43 | print("sss", sss.shape) 44 | print("list_freqs", list_freqs.shape) 45 | results = fbcca(sss, list_freqs, SAMPLE_RATE, num_harms=3, num_fbs=5) 46 | print(results) 47 | -------------------------------------------------------------------------------- /examples/load_benchmark_data.py: -------------------------------------------------------------------------------- 1 | # for running locally (may remove this if your path is right) 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | path = cwd 6 | sys.path.append(path) 7 | 8 | # imports 9 | from splearn.data import Benchmark 10 | 11 | # config 12 | load_subject_id = 1 13 | path_to_dataset = "../data/hsssvep" 14 | 15 | # load data 16 | subject_dataset = Benchmark(root=path_to_dataset, subject_id=load_subject_id) 17 | 18 | # display 19 | print("About the data:") 20 | print("sample rate:", subject_dataset.sample_rate) 21 | print("data shape:", subject_dataset.data.shape) 22 | print("targets shape:", subject_dataset.targets.shape) 23 | print("stimulus frequencies:", subject_dataset.stimulus_frequencies) 24 | print("targets frequencies:", subject_dataset.targets_frequencies) 25 | print("targets:", subject_dataset.targets) 26 | 27 | # expected output: 28 | # ``` 29 | # About the data: 30 | # sample rate: 1000 31 | # data shape: (240, 64, 1500) 32 | # targets shape: (240,) 33 | # stimulus frequencies: [ 8. 9. 10. 11. 12. 13. 14. 15. 
8.2 9.2 10.2 11.2 12.2 13.2 34 | # 14.2 15.2 8.4 9.4 10.4 11.4 12.4 13.4 14.4 15.4 8.6 9.6 10.6 11.6 35 | # 12.6 13.6 14.6 15.6 8.8 9.8 10.8 11.8 12.8 13.8 14.8 15.8] 36 | # targets frequencies: [ 8. 8. 8. 8. 8. 8. 9. 9. 9. 9. 9. 9. 10. 10. 37 | # 10. 10. 10. 10. 11. 11. 11. 11. 11. 11. 12. 12. 12. 12. 38 | # 12. 12. 13. 13. 13. 13. 13. 13. 14. 14. 14. 14. 14. 14. 39 | # 15. 15. 15. 15. 15. 15. 8.2 8.2 8.2 8.2 8.2 8.2 9.2 9.2 40 | # 9.2 9.2 9.2 9.2 10.2 10.2 10.2 10.2 10.2 10.2 11.2 11.2 11.2 11.2 41 | # 11.2 11.2 12.2 12.2 12.2 12.2 12.2 12.2 13.2 13.2 13.2 13.2 13.2 13.2 42 | # 14.2 14.2 14.2 14.2 14.2 14.2 15.2 15.2 15.2 15.2 15.2 15.2 8.4 8.4 43 | # 8.4 8.4 8.4 8.4 9.4 9.4 9.4 9.4 9.4 9.4 10.4 10.4 10.4 10.4 44 | # 10.4 10.4 11.4 11.4 11.4 11.4 11.4 11.4 12.4 12.4 12.4 12.4 12.4 12.4 45 | # 13.4 13.4 13.4 13.4 13.4 13.4 14.4 14.4 14.4 14.4 14.4 14.4 15.4 15.4 46 | # 15.4 15.4 15.4 15.4 8.6 8.6 8.6 8.6 8.6 8.6 9.6 9.6 9.6 9.6 47 | # 9.6 9.6 10.6 10.6 10.6 10.6 10.6 10.6 11.6 11.6 11.6 11.6 11.6 11.6 48 | # 12.6 12.6 12.6 12.6 12.6 12.6 13.6 13.6 13.6 13.6 13.6 13.6 14.6 14.6 49 | # 14.6 14.6 14.6 14.6 15.6 15.6 15.6 15.6 15.6 15.6 8.8 8.8 8.8 8.8 50 | # 8.8 8.8 9.8 9.8 9.8 9.8 9.8 9.8 10.8 10.8 10.8 10.8 10.8 10.8 51 | # 11.8 11.8 11.8 11.8 11.8 11.8 12.8 12.8 12.8 12.8 12.8 12.8 13.8 13.8 52 | # 13.8 13.8 13.8 13.8 14.8 14.8 14.8 14.8 14.8 14.8 15.8 15.8 15.8 15.8 53 | # 15.8 15.8] 54 | # targets: [ 0 0 0 0 0 0 1 1 1 1 1 1 2 2 2 2 2 2 3 3 3 3 3 3 55 | # 4 4 4 4 4 4 5 5 5 5 5 5 6 6 6 6 6 6 7 7 7 7 7 7 56 | # 8 8 8 8 8 8 9 9 9 9 9 9 10 10 10 10 10 10 11 11 11 11 11 11 57 | # 12 12 12 12 12 12 13 13 13 13 13 13 14 14 14 14 14 14 15 15 15 15 15 15 58 | # 16 16 16 16 16 16 17 17 17 17 17 17 18 18 18 18 18 18 19 19 19 19 19 19 59 | # 20 20 20 20 20 20 21 21 21 21 21 21 22 22 22 22 22 22 23 23 23 23 23 23 60 | # 24 24 24 24 24 24 25 25 25 25 25 25 26 26 26 26 26 26 27 27 27 27 27 27 61 | # 28 28 28 28 28 28 29 29 29 29 29 29 30 30 30 30 30 30 31 31 31 31 31 31 62 | # 
32 32 32 32 32 32 33 33 33 33 33 33 34 34 34 34 34 34 35 35 35 35 35 35 63 | # 36 36 36 36 36 36 37 37 37 37 37 37 38 38 38 38 38 38 39 39 39 39 39 39] 64 | # ``` -------------------------------------------------------------------------------- /examples/load_benchmark_dataset.py: -------------------------------------------------------------------------------- 1 | # for running locally (may remove this if your path is right) 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | path = cwd 6 | sys.path.append(path) 7 | 8 | # imports 9 | import numpy as np 10 | from torch.utils.data import DataLoader 11 | 12 | from splearn.data import MultipleSubjects, Benchmark 13 | from splearn.utils import Config 14 | from splearn.filter.butterworth import butter_bandpass_filter 15 | from splearn.filter.notch import notch_filter 16 | from splearn.filter.channels import pick_channels 17 | 18 | # config 19 | 20 | config = { 21 | "data": { 22 | "load_subject_ids": np.arange(1,4), # get subject #1, #2 and #3 23 | "root": "../data/hsssvep", 24 | "selected_channels": ["PZ", "PO5", "PO3", "POz", "PO4", "PO6", "O1", "Oz", "O2"], 25 | }, 26 | "training": { 27 | "batchsize": 256, 28 | }, 29 | } 30 | config = Config(config) 31 | 32 | # define custom preprocessing steps 33 | def func_preprocessing(data): 34 | data_x = data.data 35 | data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels) 36 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 37 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 38 | start_t = 35 39 | end_t = start_t + 250  # NOTE(review): fixed 250-sample window; the training scripts use config.data.duration * data.sampling_rate instead 40 | data_x = data_x[:,:,:,start_t:end_t] 41 | data.set_data(data_x) 42 | 43 | # load data 44 | data = MultipleSubjects( 45 | dataset=Benchmark, 46 | root=os.path.join(path,config.data.root), 47 | 
subject_ids=config.data.load_subject_ids, 48 | func_preprocessing=func_preprocessing, 49 | verbose=True, 50 | ) 51 | 52
| # display data info 53 | num_channel = data.data.shape[2] 54 | signal_length = data.data.shape[3] 55 | print("Final data shape:", data.data.shape) 56 | print("num of subjects", data.data.shape[0]) 57 | print("num channels: ", num_channel) 58 | print("signal length: ", signal_length) 59 | 60 | def prepare_dataloaders(test_subject_id, kfold_split=3, kfold_k=0): 61 | # Build train/val/test DataLoaders for test_subject_id and print their data shapes; the loaders are not returned. 62 | train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_split=kfold_split, kfold_k=kfold_k) 63 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 64 | val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 65 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 66 | 67 | print("train_loader shape", train_loader.dataset.data.shape) 68 | print("val_loader shape", val_loader.dataset.data.shape) 69 | print("test_loader shape", test_loader.dataset.data.shape) 70 | 71 | test_subject_id = 1 72 | prepare_dataloaders(test_subject_id) 73 | -------------------------------------------------------------------------------- /examples/load_beta_data.py: -------------------------------------------------------------------------------- 1 | # for running locally (may remove this if your path is right) 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | path = cwd 6 | sys.path.append(path) 7 | 8 | # imports 9 | from splearn.data import Beta 10 | 11 | # config 12 | load_subject_id = 1 13 | path_to_dataset = "../data/beta" 14 | 15 | # load data 16 | subject_dataset = Beta(root=path_to_dataset, subject_id=load_subject_id, verbose=True) 17 | 18 | # # display 19 | print("About the data:") 20 | print("sample rate:", subject_dataset.sample_rate)  # NOTE(review): load_jfpm_data.py and the MultipleSubjects preprocessing read .sampling_rate -- confirm Beta also exposes .sample_rate 21 | 
print("data shape:", subject_dataset.data.shape) 22 | print("targets shape:", subject_dataset.targets.shape) 23 | print("stimulus frequencies:", subject_dataset.stimulus_frequencies) 24 | 
print("targets frequencies:", subject_dataset.targets_frequencies) 25 | print("targets:", subject_dataset.targets) 26 | print("channel_names", subject_dataset.channel_names) 27 | -------------------------------------------------------------------------------- /examples/load_jfpm_data.py: -------------------------------------------------------------------------------- 1 | # for running locally (may remove this if your path is right) 2 | import os 3 | from re import sub 4 | cwd = os.getcwd() 5 | import sys 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | import scipy.io as sio 10 | import numpy as np 11 | from scipy.signal import butter, filtfilt 12 | 13 | from splearn.data import JFPM 14 | 15 | # config 16 | load_subject_id = 1 17 | path_to_dataset = "../data/jfpm" 18 | 19 | # load data 20 | subject_dataset = JFPM(root=path_to_dataset, subject_id=load_subject_id) 21 | print(subject_dataset.data.shape) 22 | print(subject_dataset.targets.shape) 23 | print(subject_dataset.sampling_rate) 24 | -------------------------------------------------------------------------------- /examples/multitask_benchmark.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import 
MultipleSubjects, Benchmark 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.filter.channels import pick_channels 32 | from splearn.nn.models import MultitaskSSVEPClassifier 33 | from splearn.nn.base import LightningModelClassifier 34 | 35 | config = { 36 | "experiment_name": "multitask_benchmark_nokfold", 37 | "data": { 38 | "load_subject_ids": np.arange(1,36), 39 | "root": "../data/hsssvep", 40 | "selected_channels": ["PZ", "PO5", "PO3", "POz", "PO4", "PO6", "O1", "Oz", "O2"], 41 | "duration": 1, 42 | }, 43 | "model": { 44 | "optimizer": "adamw", 45 | "scheduler": "cosine_with_warmup", 46 | }, 47 | "training": { 48 | "num_epochs": 100, 49 | "num_warmup_epochs": 20, 50 | "learning_rate": 0.03, 51 | "gpus": [0], 52 | "batchsize": 256, 53 | }, 54 | "testing": { 55 | "test_subject_ids": np.arange(1,36), 56 | "kfolds": np.arange(0,3), 57 | }, 58 | "seed": 1234 59 | } 60 | main_logger = Logger(filename_postfix=config["experiment_name"]) 61 | main_logger.write_to_log("Config") 62 | main_logger.write_to_log(config) 63 | config = Config(config) 64 | 65 | seed_everything(config.seed) 66 | 67 | # define custom preprocessing steps 68 | def func_preprocessing(data): 69 | data_x = data.data 70 | data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels) 71 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 72 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 73 | start_t = 35 74 | end_t = start_t + (config.data.duration * data.sampling_rate) 75 | data_x = data_x[:,:,:,start_t:end_t] 76 | data.set_data(data_x) 77 | 78 | # load data 79 | data = MultipleSubjects( 80 | dataset=Benchmark, 81 | root=os.path.join(path,config.data.root), 82 | subject_ids=config.data.load_subject_ids, 83 | 
func_preprocessing=func_preprocessing, 84 | verbose=True, 85 | ) 86 | 87 | num_channel = data.data.shape[2] 88 | num_classes = data.stimulus_frequencies.shape[0] 89 | signal_length = data.data.shape[3] 90 | 91 | 92 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 93 | 94 | ## init data 95 | # train_dataset, val_dataset, test_dataset = leave_one_subject_out(data, test_subject_id=test_subject_id, kfold_k=kfold_k) 96 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 97 | # train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 98 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 99 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 100 | # no kfold 101 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 102 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 103 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 104 | loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 105 | 106 | ## init model 107 | 108 | eegnet = MultitaskSSVEPClassifier(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length) 109 | 110 | model = LightningModelClassifier( 111 | optimizer=config.model.optimizer, 112 | scheduler=config.model.scheduler, 113 | optimizer_learning_rate=config.training.learning_rate, 114 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 115 | ) 116 | 117 | model.build_model(model=eegnet) 118 | 119 | ## train 120 | 121 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 122 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 123 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 124 | 
125 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 126 | # trainer.fit(model, train_loader, val_loader) 127 | trainer.fit(model, train_loader) 128 | 129 | ## test 130 | 131 | result = trainer.test(dataloaders=test_loader, verbose=False) 132 | test_acc = result[0]['test_acc_epoch'] 133 | 134 | return test_acc 135 | 136 | #### 137 | 138 | main_logger.write_to_log("Begin", break_line=True) 139 | 140 | test_results_acc = {} 141 | means = [] 142 | 143 | def k_fold_train_test_all_subjects(): 144 | 145 | for test_subject_id in config.testing.test_subject_ids: 146 | print() 147 | print("running test_subject_id:", test_subject_id) 148 | 149 | if test_subject_id not in test_results_acc: 150 | test_results_acc[test_subject_id] = [] 151 | 152 | # k-fold 153 | # for k in config.testing.kfolds: 154 | # test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 155 | # test_results_acc[test_subject_id].append(test_acc) 156 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 157 | # means.append(mean_acc) 158 | # one fold: 159 | mean_acc = train_test_subject_kfold(data, config, test_subject_id) 160 | means.append(mean_acc) 161 | 162 | this_result = { 163 | "test_subject_id": test_subject_id, 164 | "mean_acc": mean_acc, 165 | "acc": test_results_acc[test_subject_id], 166 | } 167 | print(this_result) 168 | main_logger.write_to_log(this_result) 169 | 170 | k_fold_train_test_all_subjects() 171 | 172 | mean_acc = np.mean(means) 173 | print() 174 | print("mean all", mean_acc) 175 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 176 | -------------------------------------------------------------------------------- /examples/multitask_beta.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = 
os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, Beta 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.filter.channels import pick_channels 32 | from splearn.nn.models import MultitaskSSVEPClassifier 33 | from splearn.nn.base import LightningModelClassifier 34 | 35 | config = { 36 | "experiment_name": "multitask_beta_nokfold", 37 | "data": { 38 | "load_subject_ids": np.arange(1,71), 39 | "root": "../data/beta", 40 | "selected_channels": ["PZ","PO3","PO5","PO4","PO6","POZ","O1","OZ","O2"], 41 | "duration": 1, 42 | }, 43 | "model": { 44 | "optimizer": "adamw", 45 | "scheduler": "cosine_with_warmup", 46 | }, 47 | "training": { 48 | "num_epochs": 100, 49 | "num_warmup_epochs": 20, 50 | "learning_rate": 0.03, 51 | "gpus": [0], 52 | "batchsize": 256, 53 | }, 54 | "testing": { 55 | "test_subject_ids": np.arange(1,71), 56 | "kfolds": np.arange(0,3), 57 | }, 58 | "seed": 1234 59 | } 60 | main_logger = Logger(filename_postfix=config["experiment_name"]) 61 | main_logger.write_to_log("Config") 62 | main_logger.write_to_log(config) 63 | config = Config(config) 64 | 65 | seed_everything(config.seed) 66 | 67 | # define custom preprocessing steps 68 | def func_preprocessing(data): 69 | data_x 
= data.data 70 | data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels) 71 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 72 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 73 | start_t = 35 74 | end_t = start_t + (config.data.duration * data.sampling_rate) 75 | data_x = data_x[:,:,:,start_t:end_t] 76 | data.set_data(data_x) 77 | 78 | # load data 79 | data = MultipleSubjects( 80 | dataset=Beta, 81 | root=os.path.join(path,config.data.root), 82 | subject_ids=config.data.load_subject_ids, 83 | func_preprocessing=func_preprocessing, 84 | verbose=True, 85 | ) 86 | 87 | num_channel = data.data.shape[2] 88 | num_classes = data.stimulus_frequencies.shape[0] 89 | signal_length = data.data.shape[3] 90 | 91 | 92 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 93 | 94 | ## init data 95 | # train_dataset, val_dataset, test_dataset = leave_one_subject_out(data, test_subject_id=test_subject_id, kfold_k=kfold_k) 96 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 97 | # train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 98 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 99 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 100 | # no kfold 101 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 102 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 103 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 104 | 105 | ## init model 106 | 107 | eegnet = MultitaskSSVEPClassifier(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length) 108 | 109 | model 
= LightningModelClassifier( 110 | optimizer=config.model.optimizer, 111 | scheduler=config.model.scheduler, 112 | optimizer_learning_rate=config.training.learning_rate, 113 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 114 | ) 115 | 116 | model.build_model(model=eegnet) 117 | 118 | ## train 119 | 120 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 121 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 122 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 123 | 124 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 125 | # trainer.fit(model, train_loader, val_loader) 126 | trainer.fit(model, train_loader) 127 | 128 | ## test 129 | 130 | result = trainer.test(dataloaders=test_loader, verbose=False) 131 | test_acc = result[0]['test_acc_epoch'] 132 | 133 | return test_acc 134 | 135 | #### 136 | 137 | main_logger.write_to_log("Begin", break_line=True) 138 | 139 | test_results_acc = {} 140 | means = [] 141 | 142 | def k_fold_train_test_all_subjects(): 143 | 144 | for test_subject_id in config.testing.test_subject_ids: 145 | print() 146 | print("running test_subject_id:", test_subject_id) 147 | 148 | if test_subject_id not in test_results_acc: 149 | test_results_acc[test_subject_id] = [] 150 | 151 | # k-fold 152 | # for k in config.testing.kfolds: 153 | # test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 154 | # test_results_acc[test_subject_id].append(test_acc) 155 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 156 | # means.append(mean_acc) 157 | # one fold: 158 | mean_acc = train_test_subject_kfold(data, config, test_subject_id) 159 | means.append(mean_acc) 160 | test_results_acc[test_subject_id].append(mean_acc)  # keep the per-subject "acc" list populated in one-fold mode too 161 | this_result = { 162 | "test_subject_id": test_subject_id, 163 | "mean_acc": mean_acc, 164 | "acc": test_results_acc[test_subject_id], 165 | } 166
| print(this_result) 167 | main_logger.write_to_log(this_result) 168 | 169 | k_fold_train_test_all_subjects() 170 | 171 | mean_acc = np.mean(means) 172 | print() 173 | print("mean all", mean_acc) 174 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 175 | -------------------------------------------------------------------------------- /examples/multitask_jfpm.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, JFPM 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.filter.channels import pick_channels 32 | from splearn.nn.models import MultitaskSSVEPClassifier 33 | from splearn.nn.base import LightningModelClassifier 34 | 35 | config = { 36 | "experiment_name": "multitask_jfpm_nokfold", 37 | "data": { 38 | "load_subject_ids": np.arange(1,11), 39 | "root": "../data/jfpm", 40 | "duration": 1, 41 | }, 42 | "model": { 43 | "optimizer": "adamw", 44 | "scheduler": "cosine_with_warmup", 45 | }, 46 | "training": { 47 | "num_epochs": 100, 48 | "num_warmup_epochs": 20, 49 | "learning_rate": 0.03, 
50 | "gpus": [0], 51 | "batchsize": 256, 52 | }, 53 | "testing": { 54 | "test_subject_ids": np.arange(1,11), 55 | "kfolds": np.arange(0,3), 56 | }, 57 | "seed": 1234 58 | } 59 | main_logger = Logger(filename_postfix=config["experiment_name"]) 60 | main_logger.write_to_log("Config") 61 | main_logger.write_to_log(config) 62 | config = Config(config) 63 | 64 | seed_everything(config.seed) 65 | 66 | # define custom preprocessing steps 67 | def func_preprocessing(data): 68 | data_x = data.data 69 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 70 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 71 | start_t = 35 72 | end_t = start_t + (config.data.duration * data.sampling_rate) 73 | data_x = data_x[:,:,:,start_t:end_t] 74 | data.set_data(data_x) 75 | 76 | # load data 77 | data = MultipleSubjects( 78 | dataset=JFPM, 79 | root=os.path.join(path,config.data.root), 80 | subject_ids=config.data.load_subject_ids, 81 | func_preprocessing=func_preprocessing, 82 | verbose=True, 83 | ) 84 | 85 | num_channel = data.data.shape[2] 86 | num_classes = data.stimulus_frequencies.shape[0] 87 | signal_length = data.data.shape[3] 88 | 89 | 90 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 91 | 92 | ## init data 93 | # train_dataset, val_dataset, test_dataset = leave_one_subject_out(data, test_subject_id=test_subject_id, kfold_k=kfold_k) 94 | # train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 95 | # train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 96 | # val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 97 | # test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 98 | # no kfold 99 | train_dataset, test_dataset = data.get_train_test_dataset(test_subject_id=test_subject_id) 100 | 
train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 101 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 102 | loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 103 | 104 | ## init model 105 | 106 | eegnet = MultitaskSSVEPClassifier(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length) 107 | 108 | model = LightningModelClassifier( 109 | optimizer=config.model.optimizer, 110 | scheduler=config.model.scheduler, 111 | optimizer_learning_rate=config.training.learning_rate, 112 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 113 | ) 114 | 115 | model.build_model(model=eegnet) 116 | 117 | ## train 118 | 119 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 120 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 121 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 122 | 123 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 124 | # trainer.fit(model, train_loader, val_loader) 125 | trainer.fit(model, train_loader) 126 | 127 | ## test 128 | 129 | result = trainer.test(dataloaders=test_loader, verbose=False) 130 | test_acc = result[0]['test_acc_epoch'] 131 | 132 | return test_acc 133 | 134 | #### 135 | 136 | main_logger.write_to_log("Begin", break_line=True) 137 | 138 | test_results_acc = {} 139 | means = [] 140 | 141 | def k_fold_train_test_all_subjects(): 142 | 143 | for test_subject_id in config.testing.test_subject_ids: 144 | print() 145 | print("running test_subject_id:", test_subject_id) 146 | 147 | if test_subject_id not in test_results_acc: 148 | test_results_acc[test_subject_id] = [] 149 | 150 | # k-fold 151 | # for k in config.testing.kfolds: 152 | # test_acc = train_test_subject_kfold(data, config, 
test_subject_id, kfold_k=k) 153 | # test_results_acc[test_subject_id].append(test_acc) 154 | # mean_acc = np.mean(test_results_acc[test_subject_id]) 155 | # means.append(mean_acc) 156 | # one fold: 157 | mean_acc = train_test_subject_kfold(data, config, test_subject_id) 158 | means.append(mean_acc) 159 | test_results_acc[test_subject_id].append(mean_acc)  # keep the per-subject "acc" list populated in one-fold mode too 160 | this_result = { 161 | "test_subject_id": test_subject_id, 162 | "mean_acc": mean_acc, 163 | "acc": test_results_acc[test_subject_id], 164 | } 165 | print(this_result) 166 | main_logger.write_to_log(this_result) 167 | 168 | k_fold_train_test_all_subjects() 169 | 170 | mean_acc = np.mean(means) 171 | print() 172 | print("mean all", mean_acc) 173 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 174 | -------------------------------------------------------------------------------- /examples/tcnn_benchmark.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | import pytorch_lightning 22 | pytorch_lightning.utilities.distributed.log.setLevel(logging.ERROR) 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import TensorBoardLogger 26 | 27 | from splearn.data import MultipleSubjects, Benchmark 28 | from splearn.utils import Logger, Config 29 | from splearn.filter.butterworth import butter_bandpass_filter 30 | from splearn.filter.notch import notch_filter 31 | from splearn.filter.channels import pick_channels 32 | from splearn.nn.models
import TimeDomainBasedCNN 33 | from splearn.nn.base import LightningModelClassifier 34 | 35 | config = { 36 | "experiment_name": "tcnn_benchmark", 37 | "data": { 38 | "load_subject_ids": np.arange(1,36), 39 | "root": "../data/hsssvep", 40 | "selected_channels": ["PZ", "PO5", "PO3", "POz", "PO4", "PO6", "O1", "Oz", "O2"], 41 | "duration": 1, 42 | }, 43 | "model": { 44 | "optimizer": "adamw", 45 | "scheduler": "cosine_with_warmup", 46 | }, 47 | "training": { 48 | "num_epochs": 100, 49 | "num_warmup_epochs": 20, 50 | "learning_rate": 0.03, 51 | "gpus": [0], 52 | "batchsize": 256, 53 | }, 54 | "testing": { 55 | "test_subject_ids": np.arange(1,36), 56 | "kfolds": np.arange(0,3), 57 | }, 58 | "seed": 1234 59 | } 60 | main_logger = Logger(filename_postfix=config["experiment_name"]) 61 | main_logger.write_to_log("Config") 62 | main_logger.write_to_log(config) 63 | config = Config(config) 64 | 65 | seed_everything(config.seed) 66 | 67 | # define custom preprocessing steps 68 | def func_preprocessing(data): 69 | data_x = data.data 70 | data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels) 71 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 72 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 73 | start_t = 35 74 | end_t = start_t + (config.data.duration * data.sampling_rate) 75 | data_x = data_x[:,:,:,start_t:end_t] 76 | data.set_data(data_x) 77 | 78 | # load data 79 | data = MultipleSubjects( 80 | dataset=Benchmark, 81 | root=os.path.join(path,config.data.root), 82 | subject_ids=config.data.load_subject_ids, 83 | func_preprocessing=func_preprocessing, 84 | verbose=True, 85 | ) 86 | 87 | num_channel = data.data.shape[2] 88 | num_classes = data.stimulus_frequencies.shape[0] 89 | signal_length = data.data.shape[3] 90 | 91 | 92 | def train_test_subject_kfold(data, config, test_subject_id, kfold_k=0): 93 | 94 | ## init data 95 | 
96 | # train_dataset, val_dataset, test_dataset = leave_one_subject_out(data, test_subject_id=test_subject_id, kfold_k=kfold_k) 97 | train_dataset, val_dataset, test_dataset = data.get_train_val_test_dataset(test_subject_id=test_subject_id, kfold_k=kfold_k) 98 | train_loader = DataLoader(train_dataset, batch_size=config.training.batchsize, shuffle=True) 99 | val_loader = DataLoader(val_dataset, batch_size=config.training.batchsize, shuffle=False) 100 | test_loader = DataLoader(test_dataset, batch_size=config.training.batchsize, shuffle=False) 101 | 102 | ## init model 103 | 104 | # eegnet = CompactEEGNet(num_channel=num_channel, num_classes=num_classes, signal_length=signal_length) 105 | base_model = TimeDomainBasedCNN(num_classes=num_classes, signal_length=signal_length) 106 | 107 | model = LightningModelClassifier( 108 | optimizer=config.model.optimizer, 109 | scheduler=config.model.scheduler, 110 | optimizer_learning_rate=config.training.learning_rate, 111 | scheduler_warmup_epochs=config.training.num_warmup_epochs, 112 | ) 113 | 114 | model.build_model(model=base_model) 115 | 116 | ## train 117 | 118 | sub_dir = "sub"+ str(test_subject_id) +"_k"+ str(kfold_k) 119 | logger_tb = TensorBoardLogger(save_dir="tensorboard_logs", name=config.experiment_name, sub_dir=sub_dir) 120 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 121 | 122 | trainer = Trainer(max_epochs=config.training.num_epochs, gpus=config.training.gpus, logger=logger_tb, progress_bar_refresh_rate=0, weights_summary=None, callbacks=[lr_monitor]) 123 | trainer.fit(model, train_loader, val_loader) 124 | 125 | ## test 126 | 127 | result = trainer.test(dataloaders=test_loader, verbose=False) 128 | test_acc = result[0]['test_acc_epoch'] 129 | 130 | return test_acc 131 | 132 | #### 133 | 134 | main_logger.write_to_log("Begin", break_line=True) 135 | 136 | test_results_acc = {} 137 | means = [] 138 | 139 | def k_fold_train_test_all_subjects(): 140 | 141 | for test_subject_id in 
config.testing.test_subject_ids: 142 | print() 143 | print("running test_subject_id:", test_subject_id) 144 | 145 | if test_subject_id not in test_results_acc: 146 | test_results_acc[test_subject_id] = [] 147 | 148 | for k in config.testing.kfolds: 149 | 150 | test_acc = train_test_subject_kfold(data, config, test_subject_id, kfold_k=k) 151 | 152 | test_results_acc[test_subject_id].append(test_acc) 153 | 154 | mean_acc = np.mean(test_results_acc[test_subject_id]) 155 | means.append(mean_acc) 156 | 157 | this_result = { 158 | "test_subject_id": test_subject_id, 159 | "mean_acc": mean_acc, 160 | "acc": test_results_acc[test_subject_id], 161 | } 162 | print(this_result) 163 | main_logger.write_to_log(this_result) 164 | 165 | k_fold_train_test_all_subjects() 166 | 167 | mean_acc = np.mean(means) 168 | print() 169 | print("mean all", mean_acc) 170 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 171 | -------------------------------------------------------------------------------- /examples/trca_benchmark.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | from splearn.data import MultipleSubjects, Benchmark 18 | from splearn.utils import Logger, Config 19 | from splearn.filter.butterworth import butter_bandpass_filter 20 | from splearn.filter.notch import notch_filter 21 | from splearn.filter.channels import pick_channels 22 | from splearn.cross_decomposition.trca import TRCA 23 | from splearn.cross_validate.leave_one_out import block_evaluation 24 | 25 | config = { 26 | "experiment_name": "trcaEnsemble_benchmark", 27 | "data": { 28 | "load_subject_ids": np.arange(1,36), 29 
| "root": "../data/hsssvep", 30 | "selected_channels": ["PZ", "PO5", "PO3", "POz", "PO4", "PO6", "O1", "Oz", "O2"], 31 | "duration": 1, 32 | }, 33 | "trca": { 34 | "ensemble": True 35 | }, 36 | "seed": 1234 37 | } 38 | main_logger = Logger(filename_postfix=config["experiment_name"]) 39 | main_logger.write_to_log("Config") 40 | main_logger.write_to_log(config) 41 | config = Config(config) 42 | 43 | # define custom preprocessing steps 44 | def func_preprocessing(data): 45 | data_x = data.data 46 | data_x = pick_channels(data_x, channel_names=data.channel_names, selected_channels=config.data.selected_channels) 47 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 48 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 49 | start_t = 35 50 | end_t = start_t + (config.data.duration * data.sampling_rate) 51 | data_x = data_x[:,:,:,start_t:end_t] 52 | data.set_data(data_x) 53 | 54 | # load data 55 | data = MultipleSubjects( 56 | dataset=Benchmark, 57 | root=os.path.join(path,config.data.root), 58 | subject_ids=config.data.load_subject_ids, 59 | func_preprocessing=func_preprocessing, 60 | verbose=True, 61 | ) 62 | 63 | num_channel = data.data.shape[2] 64 | num_classes = data.stimulus_frequencies.shape[0] 65 | signal_length = data.data.shape[3] 66 | 67 | 68 | def leave_one_block_evaluation(classifier, X, Y, block_seq_labels=None): 69 | test_results_acc = [] 70 | blocks, targets, channels, samples = X.shape 71 | 72 | main_logger.write_to_log("Begin", break_line=True) 73 | 74 | for block_i in range(blocks): 75 | test_acc = block_evaluation(classifier, X, Y, block_i) 76 | test_results_acc.append(test_acc) 77 | 78 | this_result = { 79 | "test_subject_id": block_i+1, 80 | "acc": test_acc, 81 | } 82 | 83 | main_logger.write_to_log(this_result) 84 | 85 | mean_acc = np.array(test_results_acc).mean().round(3)*100 86 | 87 | print(f'Mean test accuracy: {mean_acc}%') 88 | 89 | 
main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 90 | 91 | 92 | trca_classifier = TRCA(sampling_rate=data.sampling_rate, ensemble=config.trca.ensemble) 93 | print("data:", data.data.shape) 94 | print("targets:", data.targets.shape) 95 | leave_one_block_evaluation(classifier=trca_classifier, X=data.data, Y=data.targets) 96 | -------------------------------------------------------------------------------- /examples/trca_beta.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | from splearn.data import MultipleSubjects, Beta 18 | from splearn.utils import Logger, Config 19 | from splearn.filter.butterworth import butter_bandpass_filter 20 | from splearn.filter.notch import notch_filter 21 | from splearn.filter.channels import pick_channels 22 | from splearn.cross_decomposition.trca import TRCA 23 | from splearn.cross_validate.leave_one_out import block_evaluation 24 | 25 | config = { 26 | "experiment_name": "trcaEnsemble_beta", 27 | "data": { 28 | "load_subject_ids": np.arange(1,71), 29 | "root": "../data/beta", 30 | "selected_channels": ["PZ", "PO5", "PO3", "POz", "PO4", "PO6", "O1", "Oz", "O2"], 31 | "duration": 1, 32 | }, 33 | "trca": { 34 | "ensemble": True 35 | }, 36 | "seed": 1234 37 | } 38 | main_logger = Logger(filename_postfix=config["experiment_name"]) 39 | main_logger.write_to_log("Config") 40 | main_logger.write_to_log(config) 41 | config = Config(config) 42 | 43 | # define custom preprocessing steps 44 | def func_preprocessing(data): 45 | data_x = data.data 46 | data_x = pick_channels(data_x, channel_names=data.channel_names, 
selected_channels=config.data.selected_channels) 47 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 48 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 49 | start_t = 35 50 | end_t = start_t + (config.data.duration * data.sampling_rate) 51 | data_x = data_x[:,:,:,start_t:end_t] 52 | data.set_data(data_x) 53 | 54 | # load data 55 | data = MultipleSubjects( 56 | dataset=Beta, 57 | root=os.path.join(path,config.data.root), 58 | subject_ids=config.data.load_subject_ids, 59 | func_preprocessing=func_preprocessing, 60 | verbose=True, 61 | ) 62 | 63 | num_channel = data.data.shape[2] 64 | num_classes = data.stimulus_frequencies.shape[0] 65 | signal_length = data.data.shape[3] 66 | 67 | 68 | def leave_one_block_evaluation(classifier, X, Y, block_seq_labels=None): 69 | test_results_acc = [] 70 | blocks, targets, channels, samples = X.shape 71 | 72 | main_logger.write_to_log("Begin", break_line=True) 73 | 74 | for block_i in range(blocks): 75 | test_acc = block_evaluation(classifier, X, Y, block_i) 76 | test_results_acc.append(test_acc) 77 | 78 | this_result = { 79 | "test_subject_id": block_i+1, 80 | "acc": test_acc, 81 | } 82 | 83 | main_logger.write_to_log(this_result) 84 | 85 | mean_acc = np.array(test_results_acc).mean().round(3)*100 86 | 87 | print(f'Mean test accuracy: {mean_acc}%') 88 | 89 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 90 | 91 | 92 | trca_classifier = TRCA(sampling_rate=data.sampling_rate, ensemble=config.trca.ensemble) 93 | print("data:", data.data.shape) 94 | print("targets:", data.targets.shape) 95 | leave_one_block_evaluation(classifier=trca_classifier, X=data.data, Y=data.targets) 96 | -------------------------------------------------------------------------------- /examples/trca_jfpm.py: -------------------------------------------------------------------------------- 1 | # for running locally 2 | import os 3 | cwd = os.getcwd() 
4 | import sys 5 | # path = os.path.join(cwd, "..\\..\\") 6 | path = cwd 7 | sys.path.append(path) 8 | 9 | # imports 10 | import numpy as np 11 | import logging 12 | logging.getLogger('lightning').setLevel(0) 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | from splearn.data import MultipleSubjects, JFPM 18 | from splearn.utils import Logger, Config 19 | from splearn.filter.butterworth import butter_bandpass_filter 20 | from splearn.filter.notch import notch_filter 21 | from splearn.cross_decomposition.trca import TRCA 22 | from splearn.cross_validate.leave_one_out import block_evaluation 23 | 24 | config = { 25 | "experiment_name": "trcaEnsemble_jfpm", 26 | "data": { 27 | "load_subject_ids": np.arange(1,11), 28 | "root": "../data/jfpm", 29 | "duration": 1, 30 | }, 31 | "trca": { 32 | "ensemble": True 33 | }, 34 | "seed": 1234 35 | } 36 | main_logger = Logger(filename_postfix=config["experiment_name"]) 37 | main_logger.write_to_log("Config") 38 | main_logger.write_to_log(config) 39 | config = Config(config) 40 | 41 | # define custom preprocessing steps 42 | def func_preprocessing(data): 43 | data_x = data.data 44 | data_x = notch_filter(data_x, sampling_rate=data.sampling_rate, notch_freq=50.0) 45 | data_x = butter_bandpass_filter(data_x, lowcut=7, highcut=90, sampling_rate=data.sampling_rate, order=6) 46 | start_t = 35 47 | end_t = start_t + (config.data.duration * data.sampling_rate) 48 | data_x = data_x[:,:,:,start_t:end_t] 49 | data.set_data(data_x) 50 | 51 | # load data 52 | data = MultipleSubjects( 53 | dataset=JFPM, 54 | root=os.path.join(path,config.data.root), 55 | subject_ids=config.data.load_subject_ids, 56 | func_preprocessing=func_preprocessing, 57 | verbose=True, 58 | ) 59 | 60 | num_channel = data.data.shape[2] 61 | num_classes = data.stimulus_frequencies.shape[0] 62 | signal_length = data.data.shape[3] 63 | 64 | 65 | def leave_one_block_evaluation(classifier, X, Y, block_seq_labels=None): 66 | test_results_acc = [] 67 | 
blocks, targets, channels, samples = X.shape 68 | 69 | main_logger.write_to_log("Begin", break_line=True) 70 | 71 | for block_i in range(blocks): 72 | test_acc = block_evaluation(classifier, X, Y, block_i) 73 | test_results_acc.append(test_acc) 74 | 75 | this_result = { 76 | "test_subject_id": block_i+1, 77 | "acc": test_acc, 78 | } 79 | 80 | main_logger.write_to_log(this_result) 81 | 82 | mean_acc = np.array(test_results_acc).mean().round(3)*100 83 | 84 | print(f'Mean test accuracy: {mean_acc}%') 85 | 86 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 87 | 88 | 89 | trca_classifier = TRCA(sampling_rate=data.sampling_rate, ensemble=config.trca.ensemble) 90 | print("data:", data.data.shape) 91 | print("targets:", data.targets.shape) 92 | leave_one_block_evaluation(classifier=trca_classifier, X=data.data, Y=data.targets) 93 | -------------------------------------------------------------------------------- /examples/trca_sample.py: -------------------------------------------------------------------------------- 1 | # this code is for reproducing the results in: https://github.com/mnakanishi/TRCA-SSVEP 2 | 3 | # for running locally 4 | import os 5 | cwd = os.getcwd() 6 | import sys 7 | # path = os.path.join(cwd, "..\\..\\") 8 | path = cwd 9 | sys.path.append(path) 10 | 11 | # imports 12 | import numpy as np 13 | import logging 14 | logging.getLogger('lightning').setLevel(0) 15 | import warnings 16 | warnings.filterwarnings('ignore') 17 | 18 | 19 | from splearn.data import SampleSSVEPData 20 | from splearn.utils import Logger, Config 21 | from splearn.cross_decomposition.trca import TRCA 22 | from splearn.cross_validate.leave_one_out import block_evaluation 23 | 24 | 25 | main_logger = Logger(filename_postfix="trca sample") 26 | main_logger.write_to_log("Config") 27 | 28 | # load data 29 | 30 | data = SampleSSVEPData() 31 | print(data.data.shape) 32 | 33 | # method 1 34 | eeg = data.get_data() 35 | labels = data.get_targets() 36 | trca_classifier 
= TRCA(sampling_rate=data.sampling_rate) 37 | test_accuracies = trca_classifier.leave_one_block_evaluation(eeg, labels) 38 | 39 | # method 2 40 | 41 | def leave_one_block_evaluation(classifier, X, Y, block_seq_labels=None): 42 | test_results_acc = [] 43 | blocks, targets, channels, samples = X.shape 44 | 45 | main_logger.write_to_log("Begin", break_line=True) 46 | 47 | for block_i in range(blocks): 48 | test_acc = block_evaluation(classifier, X, Y, block_i) 49 | test_results_acc.append(test_acc) 50 | 51 | this_result = { 52 | "test_subject_id": block_i+1, 53 | "acc": test_acc, 54 | } 55 | 56 | main_logger.write_to_log(this_result) 57 | 58 | mean_acc = np.array(test_results_acc).mean().round(3)*100 59 | 60 | print(f'Mean test accuracy: {mean_acc}%') 61 | 62 | main_logger.write_to_log("Mean acc: "+str(mean_acc), break_line=True) 63 | 64 | 65 | trca_classifier = TRCA(sampling_rate=data.sampling_rate) 66 | leave_one_block_evaluation(classifier=trca_classifier, X=data.data, Y=data.targets) 67 | 68 | 69 | # expected output: 70 | # Block: 1 | Train acc: 100.00% | Test acc: 97.50% 71 | # Block: 2 | Train acc: 100.00% | Test acc: 100.00% 72 | # Block: 3 | Train acc: 100.00% | Test acc: 100.00% 73 | # Block: 4 | Train acc: 100.00% | Test acc: 100.00% 74 | # Block: 5 | Train acc: 100.00% | Test acc: 97.50% 75 | # Block: 6 | Train acc: 100.00% | Test acc: 100.00% 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | scipy 4 | matplotlib 5 | scikit-learn 6 | pytorch-lightning 7 | torchmetrics 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="splearn", 8 | version="0.1a1", 9
| author="Jingles", 10 | author_email="jinglescode@gmail.com", 11 | description="splearn: package for signal processing and machine learning with Python.", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/jinglescode/python-signal-processing", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: BSD License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.6', 22 | ) 23 | -------------------------------------------------------------------------------- /splearn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinglescode/python-signal-processing/77e448a163caffabf36a81f839871def5441214e/splearn/__init__.py -------------------------------------------------------------------------------- /splearn/classes/classifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from splearn.cross_validate.leave_one_out import block_evaluation 3 | 4 | class Classifier(): 5 | 6 | def __init__(self): 7 | pass 8 | 9 | def predict(self, X): 10 | return None 11 | 12 | def fit(self, X, Y): 13 | pass 14 | 15 | def leave_one_block_evaluation(classifier, X, Y, block_seq_labels=None): 16 | test_results_acc = [] 17 | blocks, targets, channels, samples = X.shape 18 | 19 | for block_i in range(blocks): 20 | test_acc = block_evaluation(classifier, X, Y, block_i) 21 | test_results_acc.append(test_acc) 22 | 23 | mean_acc = np.array(test_results_acc).mean().round(3)*100 24 | 25 | print(f'Mean test accuracy: {mean_acc}%') 26 | return (mean_acc, test_results_acc) 27 | -------------------------------------------------------------------------------- /splearn/cross_decomposition/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.trca import TRCA 2 | from .cca import CCA 3 | -------------------------------------------------------------------------------- /splearn/cross_decomposition/cca.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Canonical Correlation Analysis (CCA). http://en.wikipedia.org/wiki/Canonical_correlation 3 | """ 4 | import numpy as np 5 | from sklearn.metrics import confusion_matrix 6 | import functools 7 | from ..classes.classifier import Classifier 8 | from .reference_frequencies import generate_reference_signals 9 | 10 | 11 | class CCA(Classifier): 12 | r""" 13 | Calculates the canonical correlation coefficient and 14 | corresponding weights which maximize a correlation coefficient 15 | between linear combinations of the two specified multivariable 16 | signals. 17 | 18 | Args: 19 | sampling_rate: int 20 | Sampling frequency 21 | target_frequencies : array 22 | Frequencies for SSVEP classification 23 | signal_size : int 24 | Window/segment length in time samples 25 | sampling_rate : int 26 | Sampling frequency 27 | num_harmonics : int, default: 2 28 | Generate till n-th harmonics 29 | """ 30 | def __init__(self, sampling_rate, target_frequencies, signal_size, num_harmonics=2): 31 | self.sampling_rate = sampling_rate 32 | 33 | self.reference_frequencies = generate_reference_signals( 34 | target_frequencies, 35 | size=signal_size, 36 | sampling_rate=sampling_rate, 37 | num_harmonics=num_harmonics 38 | ) 39 | self.can_train = False 40 | 41 | def predict(self, X): 42 | r""" 43 | Predict the label for each trial. 
44 | 45 | Args: 46 | X : ndarray, shape (trial, channels, samples) 47 | 3-dim signal data by trial 48 | Returns: 49 | results : array 50 | Predicted targets 51 | """ 52 | predicted_class, _, _, _, _ = perform_cca(X, self.reference_frequencies, labels=None) 53 | return predicted_class 54 | 55 | 56 | def calculate_cca(dat_x, dat_y, time_axis=-2): 57 | r""" 58 | Calculate the Canonical Correlation Analysis (CCA). 59 | This method calculates the canonical correlation coefficient and 60 | corresponding weights which maximize a correlation coefficient 61 | between linear combinations of the two specified multivariable 62 | signals. 63 | Args: 64 | dat_x : continuous Data object 65 | these data should have the same length on the time axis. 66 | dat_y : continuous Data object 67 | these data should have the same length on the time axis. 68 | time_axis : int, optional 69 | the index of the time axis in ``dat_x`` and ``dat_y``. 70 | Returns: 71 | rho : float 72 | the canonical correlation coefficient. 73 | w_x, w_y : 1d array 74 | the weights for mapping from the specified multivariable signals 75 | to canonical variables. 
76 | Raises: 77 | AssertionError : 78 | If: 79 | * ``dat_x`` and ``dat_y`` is not continuous Data object 80 | * the length of ``dat_x`` and ``dat_y`` is different on the 81 | ``time_axis`` 82 | Dependencies: 83 | functools : functools package 84 | np : numpy package 85 | Reference: 86 | https://github.com/venthur/wyrm/blob/master/wyrm/processing.py 87 | http://en.wikipedia.org/wiki/Canonical_correlation 88 | """ 89 | 90 | assert (len(dat_x.data.shape) == len(dat_y.data.shape) == 2 and 91 | dat_x.data.shape[time_axis] == dat_y.data.shape[time_axis]) 92 | 93 | if time_axis == 0 or time_axis == -2: 94 | x = dat_x.copy() 95 | y = dat_y.copy() 96 | else: 97 | x = dat_x.T.copy() 98 | y = dat_y.T.copy() 99 | 100 | # calculate covariances and it's inverses 101 | x -= x.mean(axis=0) 102 | y -= y.mean(axis=0) 103 | n = x.shape[0] 104 | c_xx = np.dot(x.T, x) / n 105 | c_yy = np.dot(y.T, y) / n 106 | c_xy = np.dot(x.T, y) / n 107 | c_yx = np.dot(y.T, x) / n 108 | ic_xx = np.linalg.pinv(c_xx) 109 | ic_yy = np.linalg.pinv(c_yy) 110 | # calculate w_x 111 | w, v = np.linalg.eig(functools.reduce(np.dot, [ic_xx, c_xy, ic_yy, c_yx])) 112 | w_x = v[:, np.argmax(w)].real 113 | w_x = w_x / np.sqrt(functools.reduce(np.dot, [w_x.T, c_xx, w_x])) 114 | # calculate w_y 115 | w, v = np.linalg.eig(functools.reduce(np.dot, [ic_yy, c_yx, ic_xx, c_xy])) 116 | w_y = v[:, np.argmax(w)].real 117 | w_y = w_y / np.sqrt(functools.reduce(np.dot, [w_y.T, c_yy, w_y])) 118 | # calculate rho 119 | rho = abs(functools.reduce(np.dot, [w_x.T, c_xy, w_y])) 120 | return rho, w_x, w_y 121 | 122 | 123 | def find_correlation_cca(signal, reference_signals): 124 | r""" 125 | Perform canonical correlation analysis (CCA) 126 | Args: 127 | signal : ndarray, shape (channel,time) 128 | Input signal in time domain 129 | reference_signals : ndarray, shape (len(flick_freq),2*num_harmonics,time) 130 | Required sinusoidal reference templates corresponding to the flicker frequency for SSVEP classification 131 | Returns: 132 | 
result : array, size: (reference_signals.shape[0]) 133 | Probability for each reference signals 134 | wx : array, size: (reference_signals.shape[0],signal.shape[0]) 135 | Wx obtain from CCA 136 | wy : array, size: (reference_signals.shape[0],signal.shape[0]) 137 | Wy obtain from CCA 138 | Dependencies: 139 | np : numpy package 140 | calculate_cca : function 141 | """ 142 | 143 | result = np.zeros(reference_signals.shape[0]) 144 | wx = np.zeros((reference_signals.shape[0],signal.shape[0])) 145 | wy = np.zeros((reference_signals.shape[0],reference_signals.shape[1])) 146 | 147 | for freq_idx in range(0, reference_signals.shape[0]): 148 | dat_y = np.squeeze(reference_signals[freq_idx, :, :]).T 149 | rho, w_x, w_y = calculate_cca(signal.T, dat_y) 150 | result[freq_idx] = rho 151 | wx[freq_idx,:] = w_x 152 | wy[freq_idx,:] = w_y 153 | return result, wx, wy 154 | 155 | 156 | def perform_cca(signal, reference_frequencies, labels=None): 157 | r""" 158 | Perform canonical correlation analysis (CCA) 159 | Args: 160 | signal : ndarray, shape (trial,channel,time) or (trial,channel,segment,time) 161 | Input signal in time domain 162 | reference_frequencies : ndarray, shape (len(flick_freq),2*num_harmonics,time) 163 | Required sinusoidal reference templates corresponding to the flicker frequency for SSVEP classification 164 | labels : ndarray shape (classes,) 165 | True labels of `signal`. 
Index of the classes must be match the sequence of `reference_frequencies` 166 | Returns: 167 | predicted_class : ndarray, size: (classes,) 168 | Predicted classes according to reference_frequencies 169 | accuracy : double 170 | If `labels` are given, `accuracy` denote classification accuracy 171 | predicted_probabilities : ndarray, size: (classes,) 172 | Predicted probabilities for each target 173 | wx : array, size: (reference_signals.shape[0],signal.shape[0]) 174 | Wx obtain from CCA 175 | wy : array, size: (reference_signals.shape[0],signal.shape[0]) 176 | Wy obtain from CCA 177 | Dependencies: 178 | confusion_matrix : sklearn.metrics.confusion_matrix 179 | find_correlation_cca : function 180 | """ 181 | 182 | assert (len(signal.shape) == 3 or len(signal.shape) == 4), "signal shape must be 3 or 4 dimension" 183 | 184 | actual_class = [] 185 | predicted_class = [] 186 | predicted_probabilities = [] 187 | accuracy = None 188 | Wx = [] 189 | Wy = [] 190 | 191 | for trial in range(0, signal.shape[0]): 192 | 193 | if len(signal.shape) == 3: 194 | if labels is not None: 195 | actual_class.append(labels[trial]) 196 | tmp_signal = signal[trial, :, :] 197 | 198 | result, wx, wy = find_correlation_cca(tmp_signal, reference_frequencies) 199 | predicted_class.append(np.argmax(result)) 200 | result = np.around(result, decimals=3, out=None) 201 | predicted_probabilities.append(result) 202 | Wx.append(wx) 203 | Wy.append(wy) 204 | 205 | if len(signal.shape) == 4: 206 | for segment in range(0, signal.shape[2]): 207 | 208 | if labels is not None: 209 | actual_class.append(labels[trial]) 210 | tmp_signal = signal[trial, :, segment, :] 211 | 212 | result, wx, wy = find_correlation_cca(tmp_signal, reference_frequencies) 213 | predicted_class.append(np.argmax(result)) 214 | result = np.around(result, decimals=3, out=None) 215 | predicted_probabilities.append(result) 216 | Wx.append(wx) 217 | Wy.append(wy) 218 | 219 | actual_class = np.array(actual_class) 220 | predicted_class = 
np.array(predicted_class) 221 | 222 | if labels is not None: 223 | # creating a confusion matrix of true versus predicted classification labels 224 | c_mat = confusion_matrix(actual_class, predicted_class) 225 | # computing the accuracy from the confusion matrix 226 | accuracy = np.divide(np.trace(c_mat), np.sum(np.sum(c_mat))) 227 | 228 | Wx = np.array(Wx) 229 | Wy = np.array(Wy) 230 | 231 | return predicted_class, accuracy, predicted_probabilities, Wx, Wy 232 | -------------------------------------------------------------------------------- /splearn/cross_decomposition/fbcca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed Oct 30 10:17:50 2019 3 | @author: ALU 4 | Steady-state visual evoked potentials (SSVEPs) detection using the filter 5 | bank canonical correlation analysis (FBCCA)-based method [1]. 6 | function results = test_fbcca(eeg, list_freqs, fs, num_harms, num_fbs) 7 | Input: 8 | eeg : Input eeg data 9 | (# of targets, # of channels, Data length [sample]) 10 | list_freqs : List for stimulus frequencies 11 | fs : Sampling frequency 12 | num_harms : # of harmonics 13 | num_fbs : # of filters in filterbank analysis 14 | Output: 15 | results : The target estimated by this method 16 | Reference: 17 | [1] X. Chen, Y. Wang, S. Gao, T. -P. Jung and X. Gao, 18 | "Filter bank canonical correlation analysis for implementing a 19 | high-speed SSVEP-based brain-computer interface", 20 | J. Neural Eng., vol.12, 046008, 2015. 
21 | """ 22 | from sklearn.cross_decomposition import CCA 23 | from .filterbank import filterbank 24 | from scipy.stats import pearsonr 25 | import numpy as np 26 | 27 | def fbcca(eeg, list_freqs, fs, num_harms=3, num_fbs=5): 28 | 29 | fb_coefs = np.power(np.arange(1,num_fbs+1),(-1.25)) + 0.25 30 | 31 | num_targs, _, num_smpls = eeg.shape #40 taget (means 40 fre-phase combination that we want to predict) 32 | y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms) 33 | cca = CCA(n_components=1) #initilize CCA 34 | 35 | # result matrix 36 | r = np.zeros((num_fbs,num_targs)) 37 | print("r", r.shape) 38 | results = np.zeros(num_targs) 39 | 40 | for targ_i in range(num_targs): 41 | test_tmp = np.squeeze(eeg[targ_i, :, :]) #deal with one target a time 42 | for fb_i in range(num_fbs): #filter bank number, deal with different filter bank 43 | testdata = filterbank(test_tmp, fs, fb_i) #data after filtering 44 | for class_i in range(num_targs): 45 | refdata = np.squeeze(y_ref[class_i, :, :]) #pick corresponding freq target reference signal 46 | print(111, testdata.T.shape, refdata.T.shape) 47 | test_C, ref_C = cca.fit_transform(testdata.T, refdata.T) 48 | # len(row) = len(observation), len(column) = variables of each observation 49 | # number of rows should be the same, so need transpose here 50 | # output is the highest correlation linear combination of two sets 51 | print(222, test_C.shape, ref_C.shape) 52 | r_tmp, _ = pearsonr(np.squeeze(test_C), np.squeeze(ref_C)) #return r and p_value, use np.squeeze to adapt the API 53 | print(333, fb_i, class_i, r_tmp) 54 | r[fb_i, class_i] = r_tmp 55 | print(444) 56 | 57 | print(333, fb_coefs.shape, r.shape) 58 | rho = np.dot(fb_coefs, r) #weighted sum of r from all different filter banks' result 59 | tau = np.argmax(rho) #get maximum from the target as the final predict (get the index) 60 | results[targ_i] = tau #index indicate the maximum(most possible) target 61 | return results 62 | 63 | ''' 64 | Generate reference signals 
for the canonical correlation analysis (CCA) 65 | -based steady-state visual evoked potentials (SSVEPs) detection [1, 2]. 66 | function [ y_ref ] = cca_reference(listFreq, fs, nSmpls, nHarms) 67 | Input: 68 | listFreq : List for stimulus frequencies 69 | fs : Sampling frequency 70 | nSmpls : # of samples in an epoch 71 | nHarms : # of harmonics 72 | Output: 73 | y_ref : Generated reference signals 74 | (# of targets, 2*# of channels, Data length [sample]) 75 | Reference: 76 | [1] Z. Lin, C. Zhang, W. Wu, and X. Gao, 77 | "Frequency Recognition Based on Canonical Correlation Analysis for 78 | SSVEP-Based BCI", 79 | IEEE Trans. Biomed. Eng., 54(6), 1172-1176, 2007. 80 | [2] G. Bin, X. Gao, Z. Yan, B. Hong, and S. Gao, 81 | "An online multi-channel SSVEP-based brain-computer interface using 82 | a canonical correlation analysis method", 83 | J. Neural Eng., 6 (2009) 046002 (6pp). 84 | ''' 85 | def cca_reference(list_freqs, fs, num_smpls, num_harms=3): 86 | 87 | num_freqs = len(list_freqs) 88 | tidx = np.arange(1,num_smpls+1)/fs #time index 89 | 90 | y_ref = np.zeros((num_freqs, 2*num_harms, num_smpls)) 91 | for freq_i in range(num_freqs): 92 | tmp = [] 93 | for harm_i in range(1,num_harms+1): 94 | stim_freq = list_freqs[freq_i] #in HZ 95 | # Sin and Cos 96 | tmp.extend([np.sin(2*np.pi*tidx*harm_i*stim_freq), 97 | np.cos(2*np.pi*tidx*harm_i*stim_freq)]) 98 | y_ref[freq_i] = tmp # 2*num_harms because include both sin and cos 99 | 100 | return y_ref 101 | 102 | 103 | ''' 104 | Base on fbcca, but adapt to our input format 105 | ''' 106 | def fbcca_realtime(data, list_freqs, fs, num_harms=3, num_fbs=5): 107 | 108 | fb_coefs = np.power(np.arange(1,num_fbs+1),(-1.25)) + 0.25 109 | 110 | num_targs = len(list_freqs) 111 | _, num_smpls = data.shape 112 | 113 | y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms) 114 | cca = CCA(n_components=1) #initialize CCA 115 | 116 | # result matrix 117 | r = np.zeros((num_fbs,num_targs)) 118 | 119 | for fb_i in range(num_fbs): 
#filter bank number, deal with different filter bank 120 | testdata = filterbank(data, fs, fb_i) #data after filtering 121 | for class_i in range(num_targs): 122 | refdata = np.squeeze(y_ref[class_i, :, :]) #pick corresponding freq target reference signal 123 | test_C, ref_C = cca.fit_transform(testdata.T, refdata.T) 124 | r_tmp, _ = pearsonr(np.squeeze(test_C), np.squeeze(ref_C)) #return r and p_value 125 | if r_tmp == np.nan: 126 | r_tmp=0 127 | r[fb_i, class_i] = r_tmp 128 | 129 | rho = np.dot(fb_coefs, r) #weighted sum of r from all different filter banks' result 130 | print(rho) #print out the correlation 131 | result = np.argmax(rho) #get maximum from the target as the final predict (get the index), and index indicates the maximum entry(most possible target) 132 | ''' Threshold ''' 133 | THRESHOLD = 2.1 134 | if abs(rho[result])>> from splearn.cross_decomposition.trca import TRCA 22 | >>> from splearn.data.sample_ssvep import SampleSSVEPData 23 | >>> from splearn.cross_validate.leave_one_out import block_evaluation 24 | >>> 25 | >>> data = SampleSSVEPData() 26 | >>> eeg = data.get_data() 27 | >>> labels = data.get_targets() 28 | >>> print("eeg.shape:", eeg.shape) 29 | >>> print("labels.shape:", labels.shape) 30 | >>> 31 | >>> trca_classifier = TRCA(sampling_rate=data.sampling_rate) 32 | >>> test_accuracies = block_evaluation(trca_classifier, X, Y, 0) 33 | """ 34 | 35 | test_accuracies = [] 36 | blocks, targets, channels, samples = X.shape 37 | 38 | for block_i in range(blocks): 39 | test_acc = block_evaluation(classifier, X, Y, block_i, block_seq_labels[block_i] if block_seq_labels is not None else None) 40 | test_accuracies.append(test_acc) 41 | 42 | print(f'Mean test accuracy: {np.array(test_accuracies).mean().round(3)*100}%') 43 | return test_accuracies 44 | 45 | def block_evaluation(classifier, X, Y, block_i, block_label=None): 46 | r""" 47 | Select a block for testing, use all other blocks for training. 
48 | 49 | Args: 50 | X : ndarray, shape (blocks, targets, channels, samples) 51 | 4-dim signal data 52 | Y : ndarray, shape (blocks, targets) 53 | Targets are int, starts from 0 54 | block_i: int 55 | Index of the selected block for testing 56 | block_label : str or int 57 | Labels for this block, for printing 58 | Returns: 59 | train_acc : float 60 | Train accuracy 61 | test_acc : float 62 | Test accuracy of the selected block 63 | """ 64 | 65 | blocks, targets, channels, samples = X.shape 66 | 67 | train_acc = 0 68 | if classifier.can_train: 69 | x_train = np.delete(X, block_i, axis=0) 70 | x_train = x_train.reshape((blocks-1*targets, channels, samples)) 71 | y_train = np.delete(Y, block_i, axis=0) 72 | y_train = y_train.reshape((blocks-1*targets)) 73 | classifier.fit(x_train, y_train) 74 | # p1 = classifier.predict(x_train) 75 | # train_acc = accuracy_score(y_train, p1) 76 | 77 | x_test = X[block_i,:,:,:] 78 | y_test = Y[block_i] 79 | p2 = classifier.predict(x_test) 80 | test_acc = accuracy_score(y_test, p2) 81 | 82 | if block_label is None: 83 | block_label = 'Block:' + str(block_i+1) 84 | 85 | # if classifier.can_train: 86 | # print(f'{block_label} | Train acc: {train_acc*100:.2f}% | Test acc: {test_acc*100:.2f}%') 87 | # else: 88 | # print(f'{block_label} | Test acc: {test_acc*100:.2f}%') 89 | 90 | print(f'{block_label} | Test acc: {test_acc*100:.2f}%') 91 | 92 | return test_acc 93 | -------------------------------------------------------------------------------- /splearn/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch_dataset import PyTorchDataset, PyTorchDataset2Views 2 | from .multiple_subjects import MultipleSubjects 3 | from .benchmark import Benchmark 4 | from .beta import Beta 5 | from .openbmi import OPENBMI 6 | from .jfpm import JFPM 7 | from .sample_ssvep import SampleSSVEPData 8 | 9 | from .generate import generate_signal 10 | 
-------------------------------------------------------------------------------- /splearn/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | from typing import Tuple 5 | 6 | from splearn.data.pytorch_dataset import PyTorchDataset 7 | 8 | 9 | class Benchmark(PyTorchDataset): 10 | """ 11 | A Benchmark Dataset for SSVEP-Based Brain–Computer Interfaces 12 | Yijun Wang, Xiaogang Chen, Xiaorong Gao, Shangkai Gao 13 | https://ieeexplore.ieee.org/document/7740878 14 | Sampling rate: 250 Hz 15 | Targets: [8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,8.2,9.2,10.2,11.2,12.2,13.2,14.2,15.2,8.4,9.4,10.4,11.4,12.4,13.4,14.4,15.4,8.6,9.6,10.6,11.6,12.6,13.6,14.6,15.6,8.8,9.8,10.8,11.8,12.8,13.8,14.8,15.8] 16 | 17 | This dataset gathered SSVEP-BCI recordings of 35 healthy subjects (17 females, aged 17-34 years, mean age: 22 years) focusing on 40 characters flickering at different frequencies (8-15.8 Hz with an interval of 0.2 Hz). For each subject, the experiment consisted of 6 blocks. Each block contained 40 trials corresponding to all 40 characters indicated in a random order. Each trial started with a visual cue (a red square) indicating a target stimulus. The cue appeared for 0.5 s on the screen. Subjects were asked to shift their gaze to the target as soon as possible within the cue duration. Following the cue offset, all stimuli started to flicker on the screen concurrently and lasted 5 s. After stimulus offset, the screen was blank for 0.5 s before the next trial began, which allowed the subjects to have short breaks between consecutive trials. Each trial lasted a total of 6 s. To facilitate visual fixation, a red triangle appeared below the flickering target during the stimulation period. In each block, subjects were asked to avoid eye blinks during the stimulation period. To avoid visual fatigue, there was a rest for several minutes between two consecutive blocks. 
18 | 19 | EEG data were acquired using a Synamps2 system (Neuroscan, Inc.) with a sampling rate of 1000 Hz. The amplifier frequency passband ranged from 0.15 Hz to 200 Hz. Sixty-four channels covered the whole scalp of the subject and were aligned according to the international 10-20 system. The ground was placed midway between Fz and FPz. The reference was located on the vertex. Electrode impedances were kept below 10 kΩ. To remove the common power-line noise, a notch filter at 50 Hz was applied in data recording. Event triggers were generated by the computer, sent to the amplifier, and recorded on an event channel synchronized to the EEG data. 20 | 21 | The continuous EEG data was segmented into 6 s epochs (500 ms pre-stimulus, 5.5 s post-stimulus onset). The epochs were subsequently downsampled to 250 Hz. Thus each trial consisted of 1500 time points. Finally, these data were stored as double-precision floating-point values in MATLAB and were named as subject indices (i.e., S01.mat, ..., S35.mat). For each file, the data loaded in MATLAB generate a 4-D matrix named "data" with dimensions of [64, 1500, 40, 6]. The four dimensions indicate "Electrode index", "Time points", "Target index", and "Block index". The electrode positions were saved in a "64-channels.loc" file. Six trials were available for each SSVEP frequency. Frequency and phase values for the 40 target indices were saved in a "Freq_Phase.mat" file. 22 | 23 | Information for all subjects was listed in a "Sub_info.txt" file. For each subject, there are five factors including "Subject Index", "Gender", "Age", "Handedness", and "Group". Subjects were divided into an "experienced" group (eight subjects, S01-S08) and a "naive" group (27 subjects, S09-S35) according to their experience in SSVEP-based BCIs.
24 | """ 25 | 26 | def __init__(self, root: str, subject_id: int, verbose: bool = False, file_prefix='S') -> None: 27 | 28 | self.root = root 29 | self.data, self.targets, self.channel_names = _load_data(self.root, subject_id, verbose, file_prefix) 30 | self.sampling_rate = 250 31 | self.stimulus_frequencies = np.array([8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,8.2,9.2,10.2,11.2,12.2,13.2,14.2,15.2,8.4,9.4,10.4,11.4,12.4,13.4,14.4,15.4,8.6,9.6,10.6,11.6,12.6,13.6,14.6,15.6,8.8,9.8,10.8,11.8,12.8,13.8,14.8,15.8]) 32 | self.targets_frequencies = self.stimulus_frequencies[self.targets] 33 | 34 | def __getitem__(self, n: int) -> Tuple[np.ndarray, int]: 35 | return (self.data[n], self.targets[n]) 36 | 37 | def __len__(self) -> int: 38 | return len(self.data) 39 | 40 | 41 | def _load_data(root, subject_id, verbose, file_prefix='S'): 42 | 43 | path = os.path.join(root, file_prefix+str(subject_id)+'.mat') 44 | data_mat = sio.loadmat(path) 45 | 46 | raw_data = data_mat['data'].copy() 47 | raw_data = np.transpose(raw_data, (2,3,0,1)) 48 | 49 | data = [] 50 | targets = [] 51 | for target_id in np.arange(raw_data.shape[0]): 52 | data.extend(raw_data[target_id]) 53 | 54 | this_target = np.array([target_id]*raw_data.shape[1]) 55 | targets.extend(this_target) 56 | 57 | data = np.array(data) 58 | 59 | # Each trial started with a 0.5-s target cue. Subjects were asked to shift their gaze to the target as soon as possible. After the cue, all stimuli started to flicker on the screen concurrently for 5 s. Then, the screen was blank for 0.5 s before the next trial began. Each trial lasted 6 s in total. 
60 | data = np.array(data)[:,:,125:1375] 61 | targets = np.array(targets) 62 | 63 | channel_names = ['FP1','FPZ','FP2','AF3','AF4','F7','F5','F3','F1','FZ','F2','F4','F6','F8','FT7','FC5','FC3','FC1','FCz','FC2','FC4','FC6','FT8','T7','C5','C3','C1','Cz','C2','C4','C6','T8','M1','TP7','CP5','CP3','CP1','CPZ','CP2','CP4','CP6','TP8','M2','P7','P5','P3','P1','PZ','P2','P4','P6','P8','PO7','PO5','PO3','POz','PO4','PO6','PO8','CB1','O1','Oz','O2','CB2'] 64 | 65 | if verbose: 66 | print('Load path:', path) 67 | print('Data shape', data.shape) 68 | print('Targets shape', targets.shape) 69 | 70 | return data, targets, channel_names 71 | -------------------------------------------------------------------------------- /splearn/data/beta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | from typing import Tuple 5 | 6 | from splearn.data.pytorch_dataset import PyTorchDataset 7 | 8 | 9 | class Beta(PyTorchDataset): 10 | """ 11 | BETA: A Large Benchmark Database Toward SSVEP-BCI Application 12 | Bingchuan Liu, Xiaoshan Huang, Yijun Wang, Xiaogang Chen and Xiaorong Gao 13 | https://www.frontiersin.org/articles/10.3389/fnins.2020.00627/full 14 | Sampling rate: 250 Hz 15 | stimulus frequencies: [8.6,8.8,9.,9.2,9.4,9.6,9.8,10.,10.2,10.4,10.6,10.8,11.,11.2,11.4,11.6,11.8,12.,12.2,12.4,12.6,12.8,13.,13.2,13.4,13.6,13.8,14.,14.2,14.4,14.6,14.8,15.,15.2,15.4,15.6,15.8,8.,8.2,8.4] 16 | channel_names ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'M1', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'M2', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2'] 17 | """ 18 | def __init__(self, root: str, subject_id: int, verbose: bool = 
False, file_prefix='neuroscan_S') -> None: 19 | self.root = root 20 | self.data, self.targets, self.channel_names, self.stimulus_frequencies = _load_data(self.root, subject_id, verbose, file_prefix) 21 | self.sampling_rate = 250 22 | self.targets_frequencies = self.stimulus_frequencies[self.targets] 23 | 24 | def __getitem__(self, n: int) -> Tuple[np.ndarray, int]: 25 | return (self.data[n], self.targets[n]) 26 | 27 | def __len__(self) -> int: 28 | return len(self.data) 29 | 30 | def _load_data(root, subject_id, verbose, file_prefix='neuroscan_S'): 31 | path = os.path.join(root, file_prefix+str(subject_id)+'.mat') 32 | data_mat = sio.loadmat(path) 33 | 34 | mat_data = data_mat['data'].copy() 35 | raw_data = mat_data[0][0][0] # this is raw data, shape: (64, 750, 4, 40) 36 | raw_data = np.transpose(raw_data, (3,2,0,1)) 37 | 38 | channel_names = [] 39 | raw_channels = mat_data[0][0][1][0][0][3] 40 | for i in raw_channels: 41 | channel_names.append(i[3][0]) 42 | 43 | stimulus_frequencies = mat_data[0][0][1][0][0][4][0] 44 | 45 | data = [] 46 | targets = [] 47 | for target_id in np.arange(raw_data.shape[0]): 48 | data.extend(raw_data[target_id]) 49 | 50 | this_target = np.array([target_id]*raw_data.shape[1]) 51 | targets.extend(this_target) 52 | 53 | data = np.array(data) # (160, 64, 750) 54 | targets = np.array(targets) 55 | 56 | # Each trial comprises 0.5-s data before the event onset and 0.5-s data after the time window of 2 s or 3 s. For S1-S15, the time window is 2 s and the trial length is 3 s, whereas for S16-S70 the time window is 3 s and the trial length is 4 s. 57 | # Trials began with a 0.5s cue (a red square covering the target) for gaze shift, which was followed by flickering on all the targets, and ended with a rest time of 0.5 s. 
58 | # We remove the 0.5s from start and end 59 | # We limit all trials from all subjects to 2 seconds 60 | data = np.array(data)[:,:,125:625] 61 | 62 | if verbose: 63 | print('Load path:', path) 64 | print('Data shape', data.shape) 65 | print('Targets shape', targets.shape) 66 | 67 | return data, targets, channel_names, stimulus_frequencies 68 | -------------------------------------------------------------------------------- /splearn/data/generate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Generate signals 3 | """ 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def generate_signal(length_seconds, sampling_rate, frequencies, func="sin", add_noise=0, plot=False, include_amplitude=False, normalize=False): 9 | r""" 10 | Generate a n-D array, `length_seconds` seconds signal at `sampling_rate` sampling rate. 11 | 12 | Args: 13 | length_seconds : float 14 | Duration of signal in seconds (i.e. `10` for a 10-seconds signal, `3.5` for a 3.5-seconds signal) 15 | sampling_rate : int 16 | The sampling rate of the signal. 17 | frequencies : 1 or 2 dimension python list a floats 18 | An array of floats, where each float is the desired frequencies to generate (i.e. [5, 12, 15] to generate a signal containing a 5-Hz, 12-Hz and 15-Hz) 19 | 2 dimension python list, i.e. [[5, 12, 15],[1]], to generate a signal with 2 channels, where the second channel containing 1-Hz signal 20 | func : string, optional, default: sin 21 | The periodic function to generate signal, either `sin` or `cos` 22 | add_noise : float, optional, default: 0 23 | Add random noise to the signal, where `0` has no noise 24 | plot : boolean, optional, default: False 25 | Plot the generated signal 26 | include_amplitude : boolean, optional, default: False 27 | Amplitude for each frequency is included in the `frequencies` param. See Usage. 
28 | normalize : boolean, optional, default: False 29 | Rescale each channel to the range [-1, 1] (applied before any noise is added) 30 | Returns: 31 | signal : n-d ndarray 32 | Generated signal, a numpy array of length `sampling_rate*length_seconds` 33 | Usage: 34 | >>> # 1 channel, contains 2hz 35 | >>> s = generate_signal( 36 | >>> length_seconds=4, 37 | >>> sampling_rate=100, 38 | >>> frequencies=[2], 39 | >>> plot=True 40 | >>> ) 41 | >>> 42 | >>> # 1 channel, 2 frequencies, 1hz and 2hz, with noise 43 | >>> s = generate_signal( 44 | >>> length_seconds=4, 45 | >>> sampling_rate=100, 46 | >>> frequencies=[1,2], 47 | >>> func="cos", 48 | >>> add_noise=0.5, 49 | >>> plot=True 50 | >>> ) 51 | >>> 52 | >>> # 3 channels 53 | >>> s = generate_signal( 54 | >>> length_seconds=3.5, 55 | >>> sampling_rate=100, 56 | >>> frequencies=[[1,2],[1],[2]], 57 | >>> plot=True 58 | >>> ) 59 | >>> 60 | >>> # single channel containing 6hz and 2hz frequencies, where 6hz's amplitude is 1, and 2hz's amplitude is 4 61 | >>> s = generate_signal( 62 | >>> length_seconds=3.5, 63 | >>> sampling_rate=100, 64 | >>> frequencies=[[6,1],[2,4]], 65 | >>> plot=True, 66 | >>> include_amplitude=True, 67 | >>> ) 68 | >>> 69 | >>> # 2-dim channels. First channel contains 1hz (1 amplitude), and 2hz (4 amplitude). Second channel contains 4hz (10 amplitude), 8hz (1 amplitude) and 10hz (4 amplitude).
70 | >>> s = generate_signal( 71 | >>> length_seconds=3.5, 72 | >>> sampling_rate=100, 73 | >>> frequencies=[ [[1,1],[2,4]], [[4,10],[8,1],[10,4]] ], 74 | >>> plot=True, 75 | >>> include_amplitude=True, 76 | >>> ) 77 | """ 78 | 79 | frequencies = np.array(frequencies, dtype=object) 80 | assert len(frequencies.shape) == 1 or len(frequencies.shape) == 2 or len(frequencies.shape) == 3, "frequencies must be 1d, 2d ore 3d python list" 81 | 82 | expanded = False 83 | if isinstance(frequencies[0], int): 84 | frequencies = np.expand_dims(frequencies, axis=0) 85 | expanded = True 86 | 87 | if not include_amplitude: 88 | frequencies = np.expand_dims(frequencies, axis=-1) 89 | 90 | if len(frequencies.shape) == 2 and include_amplitude: 91 | frequencies = np.expand_dims(frequencies, axis=0) 92 | expanded = True 93 | 94 | sampling_rate = int(sampling_rate) 95 | npnts = int(sampling_rate*length_seconds) # number of time samples 96 | time = np.arange(0, npnts)/sampling_rate 97 | signal = np.zeros((frequencies.shape[0],npnts)) 98 | 99 | for channel in range(0,frequencies.shape[0]): 100 | for this_freq in frequencies[channel]: 101 | 102 | freq_signal = None 103 | 104 | if func == "cos": 105 | freq_signal = np.cos(2*np.pi*this_freq[0]*time) 106 | else: 107 | freq_signal = np.sin(2*np.pi*this_freq[0]*time) 108 | 109 | if include_amplitude: 110 | freq_signal = freq_signal * this_freq[1] 111 | 112 | signal[channel] = signal[channel] + freq_signal 113 | 114 | if normalize: 115 | # normalize 116 | max = np.repeat(signal[channel].max()[np.newaxis], npnts) 117 | min = np.repeat(signal[channel].min()[np.newaxis], npnts) 118 | signal[channel] = (2*(signal[channel]-min)/(max-min))-1 119 | 120 | if add_noise: 121 | noise = np.random.uniform(low=0, high=add_noise, size=(frequencies.shape[0],npnts)) 122 | signal = signal + noise 123 | 124 | if plot: 125 | plt.plot(time, signal.T) 126 | plt.title('Signal with sampling rate of '+str(sampling_rate)+', lasting '+str(length_seconds)+'-seconds') 127 | 
plt.xlabel('Time (sec.)') 128 | plt.ylabel('Amplitude') 129 | plt.show() 130 | 131 | if expanded: 132 | signal = signal[0] 133 | 134 | return signal 135 | 136 | 137 | if __name__ == "__main__": 138 | 139 | # 1 channel, contains 2hz 140 | s = generate_signal( 141 | length_seconds=4, 142 | sampling_rate=100, 143 | frequencies=[2], 144 | plot=True 145 | ) 146 | 147 | # 1 channel, 2 frequencies, 1hz and 2hz 148 | s = generate_signal( 149 | length_seconds=4, 150 | sampling_rate=100, 151 | frequencies=[1,2], 152 | func="cos", 153 | add_noise=0.5, 154 | plot=True 155 | ) 156 | 157 | # 3 channels 158 | s = generate_signal( 159 | length_seconds=3.5, 160 | sampling_rate=100, 161 | frequencies=[[1,2],[1],[2]], 162 | plot=True 163 | ) 164 | 165 | # single channel containing 1hz and 2hz frequencies, where 1hz's amplitude is 1, and 2hz's amplitude is 4 166 | s = generate_signal( 167 | length_seconds=3.5, 168 | sampling_rate=100, 169 | frequencies=[[6,1],[2,4]], 170 | plot=True, 171 | include_amplitude=True, 172 | ) 173 | 174 | # 2-dim channels. First channel contains 1hz (1 amplitude), and 2hz (4 amplitude). Second channel contains 4hz (10 amplitude), 8hz (1 amplitude) and 10hz (4 amplitude). 
175 | s = generate_signal( 176 | length_seconds=3.5, 177 | sampling_rate=100, 178 | frequencies=[ [[1,1],[2,4]], [[4,10],[8,1],[10,4]] ], 179 | plot=True, 180 | include_amplitude=True, 181 | ) 182 | -------------------------------------------------------------------------------- /splearn/data/jfpm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | from typing import Tuple 5 | 6 | from splearn.data.pytorch_dataset import PyTorchDataset 7 | 8 | 9 | class JFPM(PyTorchDataset): 10 | """ 11 | A Comparison Study of Canonical Correlation Analysis Based Methods for Detecting Steady-State Visual Evoked Potentials 12 | Masaki Nakanishi, Yijun Wang, Yu-Te Wang and Tzyy-Ping Jung 13 | PLoS One, vol.10, no.10, e140703, 2015. http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0140703 14 | 15 | This dataset contains 12-class joint frequency-phase modulated steady-state visual evoked potentials (SSVEPs) acquired from 10 subjects used to estimate an online performance of brain-computer interface (BCI) in the reference study (Nakanishi et al., 2015). 
16 | 17 | * Number of targets : 12 18 | * Number of channels : 8 19 | * Number of sampling points : 1114 20 | * Number of trials : 15 21 | * Sampling rate [Hz] : 256 22 | 23 | The order of the stimulus frequencies in the EEG data: 24 | [9.25, 11.25, 13.25, 9.75, 11.75, 13.75, 10.25, 12.25, 14.25, 10.75, 12.75, 14.75] Hz 25 | 26 | Download the data: https://github.com/mnakanishi/12JFPM_SSVEP 27 | """ 28 | 29 | def __init__(self, root: str, subject_id: int, verbose: bool = False, file_prefix='S') -> None: 30 | 31 | self.root = root 32 | self.sampling_rate = 256 33 | self.data, self.targets = _load_data(self.root, subject_id, verbose, file_prefix) 34 | self.stimulus_frequencies = np.array([9.25, 11.25, 13.25, 9.75, 11.75, 13.75, 10.25, 12.25, 14.25, 10.75, 12.75, 14.75]) 35 | self.targets_frequencies = self.stimulus_frequencies[self.targets] 36 | 37 | def __getitem__(self, n: int) -> Tuple[np.ndarray, int]: 38 | return (self.data[n], self.targets[n]) 39 | 40 | def __len__(self) -> int: 41 | return len(self.data) 42 | 43 | 44 | def _load_data(root, subject_id, verbose, file_prefix='s'): 45 | 46 | path = os.path.join(root, file_prefix+str(subject_id)+'.mat') 47 | data_mat = sio.loadmat(path) 48 | 49 | raw_data = data_mat['eeg'].copy() 50 | 51 | num_classes = raw_data.shape[0] 52 | num_chan = raw_data.shape[1] 53 | num_trials = raw_data.shape[3] 54 | sample_rate = 256 55 | 56 | trial_len = int(38+0.135*sample_rate+4*sample_rate) - int(38+0.135*sample_rate) 57 | 58 | filtered_data = np.zeros((num_classes, num_chan, trial_len, num_trials)) 59 | 60 | for target in range(0, num_classes): 61 | for channel in range(0, num_chan): 62 | for trial in range(0, num_trials): 63 | signal_to_filter = np.squeeze(raw_data[target, channel, int(38+0.135*sample_rate): 64 | int(38+0.135*sample_rate+4*sample_rate), 65 | trial]) 66 | filtered_data[target, channel, :, trial] = signal_to_filter 67 | 68 | filtered_data = np.transpose(filtered_data, (0,3,1,2)) 69 | 70 | data = [] 71 | targets = [] 
72 | for target_id in np.arange(num_classes): 73 | data.extend(filtered_data[target_id]) 74 | this_target = np.array([target_id]*num_trials) 75 | targets.extend(this_target) 76 | 77 | data = np.array(data) 78 | targets = np.array(targets) 79 | 80 | if verbose: 81 | print('Load path:', path) 82 | print('Data shape', data.shape) 83 | print('Targets shape', targets.shape) 84 | 85 | return data, targets 86 | -------------------------------------------------------------------------------- /splearn/data/multiple_subjects.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.model_selection import StratifiedKFold 3 | from splearn.data.pytorch_dataset import PyTorchDataset 4 | 5 | 6 | class MultipleSubjects(PyTorchDataset): 7 | def __init__( 8 | self, 9 | dataset: PyTorchDataset, 10 | root: str, 11 | subject_ids: [], 12 | func_preprocessing=None, 13 | func_get_train_val_test_dataset=None, 14 | verbose: bool = False, 15 | ) -> None: 16 | 17 | self.root = root 18 | self.subject_ids = subject_ids 19 | 20 | self._load_multiple(root, dataset, subject_ids, func_preprocessing, verbose) 21 | self.targets_frequencies = self.stimulus_frequencies[self.targets] 22 | 23 | self.func_get_train_val_test_dataset = func_get_train_val_test_dataset 24 | 25 | def _load_multiple(self, root, dataset: PyTorchDataset, subject_ids: [], func_preprocessing, verbose: bool = False) -> None: 26 | is_first = True 27 | 28 | for subject_i in range(len(subject_ids)): 29 | 30 | subject_id = subject_ids[subject_i] 31 | print('Load subject:', subject_id) 32 | 33 | subject_dataset = dataset(root=root, subject_id=subject_id) 34 | 35 | sub_data = subject_dataset.data 36 | sub_targets = subject_dataset.targets 37 | 38 | if is_first: 39 | self.data = np.zeros((len(subject_ids), sub_data.shape[0], sub_data.shape[1], sub_data.shape[2])) 40 | self.targets = np.zeros((len(subject_ids), sub_targets.shape[0])) 41 | self.sampling_rate = 
subject_dataset.sampling_rate 42 | self.stimulus_frequencies = subject_dataset.stimulus_frequencies 43 | self.channel_names = subject_dataset.channel_names if hasattr(subject_dataset, 'channel_names') else None 44 | is_first = False 45 | 46 | self.data[subject_i, :, :, :] = sub_data 47 | self.targets[subject_i] = sub_targets 48 | 49 | self.targets = self.targets.astype(np.int32) 50 | 51 | if func_preprocessing is not None: 52 | func_preprocessing(self) 53 | 54 | def set_data(self, x): 55 | self.data = x 56 | 57 | def set_targets(self, targets): 58 | self.targets = targets 59 | 60 | def get_subject(self, subject_id): 61 | index = list(self.subject_ids).index(subject_id) 62 | return self.data[index], self.targets[index] 63 | 64 | def dataset_split_stratified(self, X, y, k=0, n_splits=3, seed=71, shuffle=True): 65 | skf = StratifiedKFold(n_splits=n_splits, random_state=seed, shuffle=shuffle) 66 | split_data = skf.split(X, y) 67 | 68 | for idx, value in enumerate(split_data): 69 | if k != idx: 70 | continue 71 | else: 72 | train_index, test_index = value 73 | X_train, X_test = X[train_index], X[test_index] 74 | y_train, y_test = y[train_index], y[test_index] 75 | return (X_train, y_train), (X_test, y_test) 76 | 77 | def get_train_val_test_dataset(self, **kwargs): 78 | if self.func_get_train_val_test_dataset is None: 79 | return self._leave_one_subject_out(**kwargs) 80 | else: 81 | return self.func_get_train_val_test_dataset(self, **kwargs) 82 | 83 | def _leave_one_subject_out(self, **kwargs): 84 | 85 | test_subject_id = kwargs["test_subject_id"] if "test_subject_id" in kwargs else 1 86 | kfold_k = kwargs["kfold_k"] if "kfold_k" in kwargs else 0 87 | kfold_split = kwargs["kfold_split"] if "kfold_split" in kwargs else 3 88 | 89 | # get test data 90 | # test_sub_idx = self.subject_ids.index(test_subject_id) 91 | test_sub_idx = np.where(self.subject_ids == test_subject_id)[0][0] 92 | selected_subject_data = self.data[test_sub_idx] 93 | selected_subject_targets = 
self.targets[test_sub_idx] 94 | test_dataset = PyTorchDataset(selected_subject_data, selected_subject_targets) 95 | 96 | # get train val data 97 | indices = np.arange(self.data.shape[0]) 98 | train_val_data = self.data[indices!=test_sub_idx, :, :, :] 99 | train_val_data = train_val_data.reshape((train_val_data.shape[0]*train_val_data.shape[1], train_val_data.shape[2], train_val_data.shape[3])) 100 | train_val_targets = self.targets[indices!=test_sub_idx, :] 101 | train_val_targets = train_val_targets.reshape((train_val_targets.shape[0]*train_val_targets.shape[1])) 102 | 103 | # train test split 104 | (X_train, y_train), (X_val, y_val) = self.dataset_split_stratified(train_val_data, train_val_targets, k=kfold_k, n_splits=kfold_split) 105 | train_dataset = PyTorchDataset(X_train, y_train) 106 | val_dataset = PyTorchDataset(X_val, y_val) 107 | 108 | return train_dataset, val_dataset, test_dataset 109 | 110 | def get_train_test_dataset(self, **kwargs): 111 | 112 | test_subject_id = kwargs["test_subject_id"] if "test_subject_id" in kwargs else 1 113 | 114 | # get test data 115 | test_sub_idx = np.where(self.subject_ids == test_subject_id)[0][0] 116 | selected_subject_data = self.data[test_sub_idx] 117 | selected_subject_targets = self.targets[test_sub_idx] 118 | test_dataset = PyTorchDataset(selected_subject_data, selected_subject_targets) 119 | 120 | # get train data 121 | indices = np.arange(self.data.shape[0]) 122 | X_train = self.data[indices!=test_sub_idx, :, :, :] 123 | X_train = X_train.reshape((X_train.shape[0]*X_train.shape[1], X_train.shape[2], X_train.shape[3])) 124 | y_train = self.targets[indices!=test_sub_idx, :] 125 | y_train = y_train.reshape((y_train.shape[0]*y_train.shape[1])) 126 | train_dataset = PyTorchDataset(X_train, y_train) 127 | return train_dataset, test_dataset 128 | -------------------------------------------------------------------------------- /splearn/data/openbmi.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | from typing import Tuple 5 | 6 | from splearn.data.pytorch_dataset import PyTorchDataset 7 | 8 | 9 | class OPENBMI(PyTorchDataset): 10 | """ 11 | EEG dataset and OpenBMI toolbox for three BCI paradigms: an investigation into BCI illiteracy. 12 | Min-Ho Lee, O-Yeon Kwon, Yong-Jeong Kim, Hong-Kyung Kim, Young-Eun Lee, John Williamson, Siamac Fazli, Seong-Whan Lee. 13 | https://academic.oup.com/gigascience/article/8/5/giz002/5304369 14 | Target frequencies: 5.45, 6.67, 8.57, 12 Hz 15 | Sampling rate: 1000 Hz 16 | """ 17 | 18 | def __init__(self, root: str, subject_id: int, session: int, verbose: bool = False) -> None: 19 | 20 | self.root = root 21 | self.sampling_rate = 1000 22 | 23 | self.data, self.targets, self.channel_names = _load_data( 24 | self.root, subject_id, session, verbose) 25 | 26 | self.stimulus_frequencies = np.array([12.0,8.57,6.67,5.45]) 27 | self.targets_frequencies = self.stimulus_frequencies[self.targets] 28 | 29 | def __getitem__(self, n: int) -> Tuple[np.ndarray, int]: 30 | return (self.data[n], self.targets[n]) 31 | 32 | def __len__(self) -> int: 33 | return len(self.data) 34 | 35 | 36 | def _load_data(root, subject_id, session, verbose): 37 | 38 | path = os.path.join(root, 'session'+str(session), 39 | 's'+str(subject_id)+'/EEG_SSVEP.mat') 40 | 41 | data_mat = sio.loadmat(path) 42 | 43 | objects_in_mat = [] 44 | for i in data_mat['EEG_SSVEP_train'][0][0]: 45 | objects_in_mat.append(i) 46 | 47 | # data 48 | data = objects_in_mat[0][:, :, :].copy() 49 | data = np.transpose(data, (1, 2, 0)) 50 | data = data.astype(np.float32) 51 | 52 | # label 53 | targets = [] 54 | for i in range(data.shape[0]): 55 | targets.append([objects_in_mat[2][0][i], 0, objects_in_mat[4][0][i]]) 56 | targets = np.array(targets) 57 | targets = targets[:, 2] 58 | targets = targets-1 59 | 60 | # channel 61 | channel_names = [v[0] 
for v in objects_in_mat[8][0]] 62 | 63 | if verbose: 64 | print('Load path:', path) 65 | print('Objects in .mat', len(objects_in_mat), 66 | data_mat['EEG_SSVEP_train'].dtype.descr) 67 | print() 68 | print('Data shape', data.shape) 69 | print('Targets shape', targets.shape) 70 | 71 | return data, targets, channel_names 72 | -------------------------------------------------------------------------------- /splearn/data/pytorch_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | 4 | 5 | class PyTorchDataset(Dataset): 6 | def __init__(self, data, targets): 7 | self.data = data 8 | self.data = self.data.astype(np.float32) 9 | self.targets = targets 10 | self.channel_names = None 11 | 12 | def __getitem__(self, index): 13 | return self.data[index], self.targets[index] 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def set_data_targets(self, data: [] = None, targets: [] = None) -> None: 19 | if data is not None: 20 | self.data = data.copy() 21 | if targets is not None: 22 | self.targets = targets.copy() 23 | self.targets = self.targets.astype(int) 24 | 25 | def set_channel_names(self,channel_names): 26 | self.channel_names = channel_names 27 | 28 | def get_data(self): 29 | r""" 30 | Data shape: (6, 40, 9, 1250) [# of blocks, # of targets, # of channels, # of sampling points] 31 | """ 32 | return self.data 33 | 34 | def get_targets(self): 35 | r""" 36 | Targets index from 0 to 39. 
Shape: (6, 40) [# of blocks, # of targets] 37 | """ 38 | return self.targets 39 | 40 | def get_stimulus_frequencies(self): 41 | r""" 42 | A list of frequencies of each stimulus: 43 | [8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,8.2,9.2,10.2,11.2,12.2,13.2,14.2,15.2,8.4,9.4,10.4,11.4,12.4,13.4,14.4,15.4,8.6,9.6,10.6,11.6,12.6,13.6,14.6,15.6,8.8,9.8,10.8,11.8,12.8,13.8,14.8,15.8] 44 | """ 45 | return self.stimulus_frequencies 46 | 47 | def get_targets_frequencies(self): 48 | r""" 49 | Targets by frequencies, range between 8.0 Hz to 15.8 Hz. 50 | Shape: (6, 40) [# of blocks, # of targets] 51 | """ 52 | return self.targets_frequencies 53 | 54 | class PyTorchDataset2Views(Dataset): 55 | def __init__(self, data_view1, data_view2, targets): 56 | self.data_view1 = data_view1.astype(np.float32) 57 | self.data_view2 = data_view2.astype(np.float32) 58 | self.targets = targets 59 | 60 | def __getitem__(self, index): 61 | return self.data_view1[index], self.data_view2[index], self.targets[index] 62 | 63 | def __len__(self): 64 | return len(self.data_view1) -------------------------------------------------------------------------------- /splearn/data/sample/ssvep.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinglescode/python-signal-processing/77e448a163caffabf36a81f839871def5441214e/splearn/data/sample/ssvep.mat -------------------------------------------------------------------------------- /splearn/data/sample_ssvep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """A 40-target SSVEP dataset recorded from a single subject. 3 | """ 4 | import numpy as np 5 | from scipy.io import loadmat 6 | import os 7 | 8 | 9 | class SampleSSVEPData(): 10 | r""" 11 | A 40-target SSVEP dataset recorded from a single subject. 
12 | 13 | Data description: 14 | Original Data shape : (40, 9, 1250, 6) [# of targets, # of channels, # of sampling points, # of blocks] 15 | Stimulus frequencies : 8.0 - 15.8 Hz with an interval of 0.2 Hz 16 | Stimulus phases : 0pi, 0.5pi, 1.0pi, and 1.5pi 17 | Number of channels : 9 (1: Pz, 2: PO5,3: PO3, 4: POz, 5: PO4, 6: PO6, 7: O1, 8: Oz, and 9: O2) 18 | Number of recording blocks : 6 19 | Length of an epoch : 5 seconds 20 | Sampling rate : 250 Hz 21 | Args: 22 | path: str, default: None 23 | Path to ssvepdata.mat file 24 | Usage: 25 | >>> from splearn.data import SampleSSVEPData 26 | >>> 27 | >>> data = SampleSSVEPData() 28 | >>> eeg = data.get_data() 29 | >>> labels = data.get_targets() 30 | >>> print("eeg.shape:", eeg.shape) 31 | >>> print("labels.shape:", labels.shape) 32 | Reference: 33 | https://www.pnas.org/content/early/2015/10/14/1508080112.abstract 34 | """ 35 | def __init__(self, path=None): 36 | if path is None: 37 | path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sample") 38 | 39 | # Get EEG data 40 | data = loadmat(os.path.join(path,"ssvep.mat")) 41 | data = data["eeg"] 42 | data = data.transpose([3,0,1,2]) 43 | self.data = data 44 | 45 | # Prepare targets 46 | n_blocks, n_targets, n_channels, n_samples = self.data.shape 47 | targets = np.tile(np.arange(0, n_targets+0), (1, n_blocks)) 48 | targets = targets.reshape((n_blocks, n_targets)) 49 | self.targets = targets 50 | 51 | # Prepare targets frequencies 52 | self.stimulus_frequencies = np.array([8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,8.2,9.2,10.2,11.2,12.2,13.2,14.2,15.2,8.4,9.4,10.4,11.4,12.4,13.4,14.4,15.4,8.6,9.6,10.6,11.6,12.6,13.6,14.6,15.6,8.8,9.8,10.8,11.8,12.8,13.8,14.8,15.8]) 53 | 54 | targets_frequencies = np.tile(self.stimulus_frequencies, (1, n_blocks)) 55 | targets_frequencies = targets_frequencies.reshape((n_blocks, n_targets)) 56 | self.targets_frequencies = targets_frequencies 57 | 58 | self.sampling_rate = 250 59 | self.channels = ["Pz", "PO5","PO3", "POz", 
"PO4", "PO6", "O1", "Oz", "O2"] 60 | 61 | def get_data(self): 62 | r""" 63 | Data shape: (6, 40, 9, 1250) [# of blocks, # of targets, # of channels, # of sampling points] 64 | """ 65 | return self.data 66 | 67 | def get_targets(self): 68 | r""" 69 | Targets index from 0 to 39. Shape: (6, 40) [# of blocks, # of targets] 70 | """ 71 | return self.targets 72 | 73 | def get_stimulus_frequencies(self): 74 | r""" 75 | A list of frequencies of each stimulus: 76 | [8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,8.2,9.2,10.2,11.2,12.2,13.2,14.2,15.2,8.4,9.4,10.4,11.4,12.4,13.4,14.4,15.4,8.6,9.6,10.6,11.6,12.6,13.6,14.6,15.6,8.8,9.8,10.8,11.8,12.8,13.8,14.8,15.8] 77 | """ 78 | return self.stimulus_frequencies 79 | 80 | def get_targets_frequencies(self): 81 | r""" 82 | Targets by frequencies, range between 8.0 Hz to 15.8 Hz. 83 | Shape: (6, 40) [# of blocks, # of targets] 84 | """ 85 | return self.targets_frequencies 86 | 87 | 88 | if __name__ == "__main__": 89 | from splearn.data.sample_ssvep import SampleSSVEPData 90 | 91 | data = SampleSSVEPData() 92 | eeg = data.get_data() 93 | labels = data.get_targets() 94 | print("eeg.shape:", eeg.shape) 95 | print("labels.shape:", labels.shape) 96 | -------------------------------------------------------------------------------- /splearn/data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def onehot_targets(targets): 4 | return (np.arange(targets.max()+1) == targets[...,None]).astype(int) -------------------------------------------------------------------------------- /splearn/filter/__init__.py: -------------------------------------------------------------------------------- 1 | from .channels import pick_channels 2 | from .butterworth import butter_bandpass_filter 3 | from .notch import notch_filter -------------------------------------------------------------------------------- /splearn/filter/cca_spatial_filtering.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Use CCA for spatial filtering to improve the signal. 3 | """ 4 | import numpy as np 5 | from splearn.cross_decomposition.cca import perform_cca 6 | 7 | 8 | def cca_spatial_filtering(signal, reference_frequencies): 9 | r""" 10 | Use CCA for spatial filtering is to find a spatial filter that maximizes the correlation between the spatially filtered signal and the average evoked response, thereby improving the signal-to-noise ratio of the filtered signal on a single-trial basis. 11 | Read more: https://github.com/jinglescode/papers/issues/90, https://github.com/jinglescode/papers/issues/89 12 | Args: 13 | signal : ndarray, shape (trial,channel,time) 14 | Input signal in time domain 15 | reference_frequencies : ndarray, shape (len(flick_freq),2*num_harmonics,time) 16 | Required sinusoidal reference templates corresponding to the flicker frequency for SSVEP classification 17 | Returns: 18 | filtered_signal : ndarray, shape (reference_frequencies.shape[0],signal.shape[0],signal.shape[1],signal.shape[2]) 19 | Signal after spatial filter 20 | Dependencies: 21 | np : numpy package 22 | perform_cca : function 23 | """ 24 | _, _, _, wx, _ = perform_cca(signal, reference_frequencies) 25 | filtered_signal = np.zeros((reference_frequencies.shape[0], signal.shape[0], signal.shape[1], signal.shape[2])) 26 | 27 | swapped_s = np.swapaxes(x_train, 1, 2) 28 | 29 | for target_i in range(reference_frequencies.shape[0]): 30 | for trial_i in range(swapped_s.shape[0]): 31 | t_trial = swapped_s[trial_i] 32 | t_w = wx[trial_i,target_i,:] 33 | filtered_s = np.matmul(t_trial, t_w) 34 | filtered_signal[target_i,trial_i,:,:] = filtered_s 35 | 36 | return filtered_signal 37 | -------------------------------------------------------------------------------- /splearn/filter/channels.py: -------------------------------------------------------------------------------- 1 | import numpy as 
np 2 | 3 | 4 | def pick_channels(data: np.ndarray, 5 | channel_names: [str], 6 | selected_channels: [str], 7 | verbose: bool = False) -> np.ndarray: 8 | 9 | picked_ch = pick_channels_mne(channel_names, selected_channels) 10 | 11 | if len(data.shape) == 3: 12 | data = data[:, picked_ch, :] 13 | if len(data.shape) == 4: 14 | data = data[:, :, picked_ch, :] 15 | 16 | if verbose: 17 | print('picking channels: channel_names', 18 | len(channel_names), channel_names) 19 | print('picked_ch', picked_ch) 20 | print() 21 | 22 | del picked_ch 23 | 24 | return data 25 | 26 | 27 | def pick_channels_mne(ch_names, include, exclude=[], ordered=False): 28 | """Pick channels by names. 29 | Returns the indices of ``ch_names`` in ``include`` but not in ``exclude``. 30 | Taken from https://github.com/mne-tools/mne-python/blob/master/mne/io/pick.py 31 | Parameters 32 | ---------- 33 | ch_names : list of str 34 | List of channels. 35 | include : list of str 36 | List of channels to include (if empty include all available). 37 | .. note:: This is to be treated as a set. The order of this list 38 | is not used or maintained in ``sel``. 39 | exclude : list of str 40 | List of channels to exclude (if empty do not exclude any channel). 41 | Defaults to []. 42 | ordered : bool 43 | If true (default False), treat ``include`` as an ordered list 44 | rather than a set, and any channels from ``include`` are missing 45 | in ``ch_names`` an error will be raised. 46 | .. versionadded:: 0.18 47 | Returns 48 | ------- 49 | sel : array of int 50 | Indices of good channels. 
51 | See Also 52 | -------- 53 | pick_channels_regexp, pick_types 54 | """ 55 | if len(np.unique(ch_names)) != len(ch_names): 56 | raise RuntimeError('ch_names is not a unique list, picking is unsafe') 57 | # _check_excludes_includes(include) 58 | # _check_excludes_includes(exclude) 59 | if not ordered: 60 | if not isinstance(include, set): 61 | include = set(include) 62 | if not isinstance(exclude, set): 63 | exclude = set(exclude) 64 | sel = [] 65 | for k, name in enumerate(ch_names): 66 | if (len(include) == 0 or name in include) and name not in exclude: 67 | sel.append(k) 68 | else: 69 | if not isinstance(include, list): 70 | include = list(include) 71 | if len(include) == 0: 72 | include = list(ch_names) 73 | if not isinstance(exclude, list): 74 | exclude = list(exclude) 75 | sel, missing = list(), list() 76 | for name in include: 77 | if name in ch_names: 78 | if name not in exclude: 79 | sel.append(ch_names.index(name)) 80 | else: 81 | missing.append(name) 82 | if len(missing): 83 | raise ValueError('Missing channels from ch_names required by ' 84 | 'include:\n%s' % (missing,)) 85 | return np.array(sel, int) -------------------------------------------------------------------------------- /splearn/filter/notch.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import filtfilt, iirnotch 2 | 3 | def notch_filter(data, sampling_rate=1000, notch_freq=50.0, quality_factor=30.0): 4 | b_notch, a_notch = iirnotch(notch_freq, quality_factor, sampling_rate) 5 | data_notched = filtfilt(b_notch, a_notch, data) 6 | return data_notched -------------------------------------------------------------------------------- /splearn/fourier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Fourier analysis is a method for expressing a function as a sum of periodic components, and for recovering the signal from those components. 
3 | """ 4 | import numpy as np 5 | from scipy.fft import fft 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def fast_fourier_transform(signal, sampling_rate, plot=False, **kwargs): 10 | r""" 11 | Use Fourier transforms to find the frequency components of a signal buried in noise. 12 | 13 | Args: 14 | signal : ndarray, shape (time,) or (channel,time) or (trial,channel,time) 15 | Single input signal in time domain 16 | sampling_rate: int 17 | Sampling frequency 18 | plot : boolean, optional, default: False 19 | To plot the single-sided amplitude spectrum 20 | plot_xlim : array of shape [lower, upper], optional, default: [0, int(`sampling_rate`/2)] 21 | If `plot=True`, set a limit on the X-axis between lower and upper bound 22 | plot_ylim : array of shape [lower, upper], optional, default: None 23 | If `plot=True`, set a limit on the Y-axis between lower and upper bound 24 | plot_label : string, optional, default: '' 25 | If `plot=True`, text label for this signal in plot, shown in legend 26 | plot_line_freq : int or float or list, option, default: None 27 | If `plot=True`, plot a vertical line to mark the target frequency. If a list is given, will plot multiple lines. 28 | Returns: 29 | P1 : ndarray 30 | Frequency domain. Compute the two-sided spectrum P2. Then compute the single-sided spectrum P1 based on P2 and the even-valued signal length L. 
31 | See https://www.mathworks.com/help/matlab/ref/fft.html 32 | Usage: 33 | >>> from splearn.data.generate import generate_signal 34 | >>> from splearn.fourier import fast_fourier_transform 35 | >>> 36 | >>> s1 = generate_signal( 37 | >>> length_seconds=3.5, 38 | >>> sampling_rate=100, 39 | >>> frequencies=[4,7], 40 | >>> plot=True 41 | >>> ) 42 | >>> 43 | >>> p1 = fast_fourier_transform( 44 | >>> signal=s1, 45 | >>> sampling_rate=100, 46 | >>> plot=True, 47 | >>> plot_xlim=[0, 10], 48 | >>> plot_line_freq=7 49 | >>> ) 50 | Reference: 51 | - https://www.mathworks.com/help/matlab/ref/fft.html 52 | - https://docs.scipy.org/doc/scipy/reference/tutorial/fft.html 53 | - https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.fft.html 54 | """ 55 | 56 | plot_xlim = kwargs['plot_xlim'] if 'plot_xlim' in kwargs else [0, int(sampling_rate/2)] 57 | plot_ylim = kwargs['plot_ylim'] if 'plot_ylim' in kwargs else None 58 | plot_label = kwargs['plot_label'] if 'plot_label' in kwargs else '' 59 | plot_line_freq = kwargs['plot_line_freq'] if 'plot_line_freq' in kwargs else None 60 | plot_label = kwargs['plot_label'] if 'plot_label' in kwargs else '' 61 | 62 | fft_p1 = None 63 | 64 | if len(signal.shape) == 1: 65 | fft_p1 = _fast_fourier_transform(signal, sampling_rate) 66 | fft_p1 = np.expand_dims(fft_p1,0) 67 | 68 | if len(signal.shape) == 2: 69 | for ch in range(signal.shape[0]): 70 | fft_c = _fast_fourier_transform(signal[ch, :], sampling_rate=sampling_rate) 71 | 72 | if fft_p1 is None: 73 | fft_p1 = np.zeros((signal.shape[0], fft_c.shape[0])) 74 | 75 | fft_p1[ch] = fft_c 76 | 77 | if len(signal.shape) == 3: 78 | for trial in range(signal.shape[0]): 79 | for ch in range(signal.shape[1]): 80 | fft_c = _fast_fourier_transform(signal[trial, ch, :], sampling_rate=sampling_rate) 81 | 82 | if fft_p1 is None: 83 | fft_p1 = np.zeros((signal.shape[0], signal.shape[1], fft_c.shape[0])) 84 | 85 | fft_p1[trial,ch,:] = fft_c 86 | 87 | if plot: 88 | signal_length = signal.shape[ 
len(signal.shape)-1 ] 89 | f = sampling_rate*np.arange(0, (signal_length/2)+1)/signal_length 90 | 91 | if len(fft_p1.shape) == 3: 92 | means = np.mean(fft_p1, 0) 93 | stds = np.std(fft_p1, 0) 94 | for c in range(fft_p1.shape[1]): 95 | plt.plot(f, means[c], label=plot_label) 96 | plt.xlim(plot_xlim) 97 | plt.fill_between(f, means[c]-stds[c],means[c]+stds[c],alpha=.1) 98 | else: 99 | for c in range(fft_p1.shape[0]): 100 | plt.plot(f, fft_p1[c], label=plot_label) 101 | plt.xlim(plot_xlim) 102 | 103 | if plot_ylim is not None: 104 | plt.ylim(plot_ylim) 105 | 106 | if plot_label != '': 107 | plt.legend() 108 | 109 | if plot_line_freq is not None: 110 | if isinstance(plot_line_freq, list): 111 | for i in plot_line_freq: 112 | plt.axvline(x=i, color='r', linewidth=1.5) 113 | else: 114 | plt.axvline(x=plot_line_freq, color='r', linewidth=1.5) 115 | 116 | if len(signal.shape) == 1: 117 | fft_p1 = fft_p1[0] 118 | 119 | return fft_p1 120 | 121 | def _fast_fourier_transform(signal, sampling_rate): 122 | r""" 123 | Use Fourier transforms to find the frequency components of a signal buried in noise. 124 | 125 | Args: 126 | signal : ndarray, shape (time,) 127 | Single input signal in time domain 128 | sampling_rate: int 129 | Sampling frequency 130 | Returns: 131 | P1 : ndarray 132 | Frequency domain. Compute the two-sided spectrum P2. Then compute the single-sided spectrum P1 based on P2 and the even-valued signal length L. 133 | See https://www.mathworks.com/help/matlab/ref/fft.html. 
134 | Usage: 135 | See `fast_fourier_transform` 136 | Reference: 137 | - https://www.mathworks.com/help/matlab/ref/fft.html 138 | - https://docs.scipy.org/doc/scipy/reference/tutorial/fft.html 139 | - https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.fft.html 140 | """ 141 | 142 | signal_length = signal.shape[0] 143 | 144 | if signal_length % 2 != 0: 145 | signal_length = signal_length+1 146 | 147 | y = fft(signal) 148 | p2 = np.abs(y/signal_length) 149 | p1 = p2[0:round(signal_length/2+1)] 150 | p1[1:-1] = 2*p1[1:-1] 151 | 152 | return p1 153 | 154 | 155 | if __name__ == "__main__": 156 | 157 | from splearn.data.generate import generate_signal 158 | from splearn.fourier import fast_fourier_transform 159 | 160 | s1 = generate_signal( 161 | length_seconds=3.5, 162 | sampling_rate=100, 163 | frequencies=[4,7], 164 | plot=True 165 | ) 166 | 167 | p1 = fast_fourier_transform( 168 | signal=s1, 169 | sampling_rate=100, 170 | plot=True, 171 | plot_xlim=[0, 10], 172 | plot_line_freq=7 173 | ) 174 | -------------------------------------------------------------------------------- /splearn/nn/base/__init__.py: -------------------------------------------------------------------------------- 1 | from splearn.nn.base.lightning import LightningModel 2 | from splearn.nn.base.classifier import LightningModelClassifier -------------------------------------------------------------------------------- /splearn/nn/base/classifier.py: -------------------------------------------------------------------------------- 1 | import torchmetrics 2 | from splearn.nn.base import LightningModel 3 | from splearn.nn.loss import LabelSmoothCrossEntropyLoss 4 | 5 | 6 | class LightningModelClassifier(LightningModel): 7 | def __init__( 8 | self, 9 | optimizer="adamw", 10 | scheduler="cosine_with_warmup", 11 | optimizer_learning_rate: float=1e-3, 12 | optimizer_epsilon: float=1e-6, 13 | optimizer_weight_decay: float=0.0005, 14 | scheduler_warmup_epochs: int=10, 15 | ): 16 | super().__init__() 
17 | self.save_hyperparameters() 18 | 19 | self.train_acc = torchmetrics.Accuracy() 20 | self.valid_acc = torchmetrics.Accuracy() 21 | self.test_acc = torchmetrics.Accuracy() 22 | 23 | self.criterion_classifier = LabelSmoothCrossEntropyLoss(smoothing=0.3) # F.cross_entropy() 24 | 25 | def build_model(self, model): 26 | self.model = model 27 | 28 | def forward(self, x): 29 | y_hat = self.model(x) 30 | return y_hat 31 | 32 | def step(self, batch, batch_idx): 33 | x, y = batch 34 | y_hat = self.forward(x) 35 | loss = self.criterion_classifier(y_hat, y.long()) # self.criterion_classifier(y_hat, y.long()) # F.cross_entropy(y_hat, y.long()) 36 | return y_hat, y, loss 37 | 38 | def training_step(self, batch, batch_idx): 39 | y_hat, y, loss = self.step(batch, batch_idx) 40 | acc = self.train_acc(y_hat, y.long()) 41 | self.log('train_loss', loss, on_step=True) 42 | return loss 43 | 44 | def validation_step(self, batch, batch_idx): 45 | y_hat, y, loss = self.step(batch, batch_idx) 46 | acc = self.valid_acc(y_hat, y.long()) 47 | self.log('valid_loss', loss, on_step=True) 48 | return loss 49 | 50 | def test_step(self, batch, batch_idx): 51 | y_hat, y, loss = self.step(batch, batch_idx) 52 | acc = self.test_acc(y_hat, y.long()) 53 | self.log('test_loss', loss) 54 | return loss 55 | 56 | def training_epoch_end(self, outs): 57 | self.log('train_acc_epoch', self.train_acc.compute()) 58 | 59 | def validation_epoch_end(self, outs): 60 | self.log('valid_acc_epoch', self.valid_acc.compute()) 61 | 62 | def test_epoch_end(self, outs): 63 | self.log('test_acc_epoch', self.test_acc.compute()) -------------------------------------------------------------------------------- /splearn/nn/base/lightning.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningModule 2 | from splearn.nn.optimization import get_scheduler, get_optimizer, get_num_steps 3 | 4 | 5 | class LightningModel(LightningModule): 6 | def __init__( 7 | 
self 8 | ): 9 | super().__init__() 10 | 11 | def forward(self, x): 12 | raise NotImplementedError 13 | 14 | def training_step(self, batch, batch_idx): 15 | raise NotImplementedError 16 | 17 | def validation_step(self, batch, batch_idx): 18 | raise NotImplementedError 19 | 20 | def test_step(self, batch, batch_idx): 21 | raise NotImplementedError 22 | 23 | def configure_optimizers(self): 24 | 25 | optimizer = get_optimizer( 26 | name=self.hparams.optimizer, 27 | model=self, 28 | lr=self.hparams.optimizer_learning_rate, 29 | weight_decay=self.hparams.optimizer_weight_decay, 30 | epsilon=self.hparams.optimizer_epsilon 31 | ) 32 | 33 | total_train_steps, num_warmup_steps = get_num_steps(self) 34 | 35 | scheduler = get_scheduler( 36 | name=self.hparams.scheduler, 37 | optimizer=optimizer, 38 | num_warmup_steps=num_warmup_steps, 39 | num_training_steps=total_train_steps, 40 | ) 41 | 42 | scheduler = {'scheduler': scheduler, 'interval': 'step', 'frequency': 1} 43 | return [optimizer], [scheduler] -------------------------------------------------------------------------------- /splearn/nn/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | LabelSmoothCrossEntropyLoss 3 | https://github.com/pytorch/pytorch/issues/7455 4 | """ 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.nn.modules.loss import _WeightedLoss 8 | 9 | 10 | class LabelSmoothCrossEntropyLoss(_WeightedLoss): 11 | def __init__(self, weight=None, reduction='mean', smoothing=0.0): 12 | super().__init__(weight=weight, reduction=reduction) 13 | self.smoothing = smoothing 14 | self.weight = weight 15 | self.reduction = reduction 16 | 17 | @staticmethod 18 | def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0): 19 | assert 0 <= smoothing < 1 20 | with torch.no_grad(): 21 | targets = torch.empty(size=(targets.size(0), n_classes), 22 | device=targets.device) \ 23 | .fill_(smoothing / (n_classes - 1)) \ 24 | .scatter_(1, 
targets.data.unsqueeze(1), 1. - smoothing) 25 | return targets 26 | 27 | def forward(self, inputs, targets): 28 | targets = LabelSmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1), 29 | self.smoothing) 30 | lsm = F.log_softmax(inputs, -1) 31 | 32 | if self.weight is not None: 33 | lsm = lsm * self.weight.unsqueeze(0) 34 | 35 | loss = -(targets * lsm).sum(-1) 36 | 37 | if self.reduction == 'sum': 38 | loss = loss.sum() 39 | elif self.reduction == 'mean': 40 | loss = loss.mean() 41 | 42 | return loss -------------------------------------------------------------------------------- /splearn/nn/models/ConvCA/ConvCA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from splearn.nn.modules.conv2d import Conv2d 5 | 6 | class SignalCNN(nn.Module): 7 | """ 8 | Convolutional correlation analysis for enhancing the performance of SSVEP-based brain-computer interface 9 | https://ieeexplore.ieee.org/abstract/document/9261605/ 10 | https://github.com/yaoli90/Conv-CA/blob/main/convca.py 11 | """ 12 | def __init__(self, num_channel=10): 13 | super().__init__() 14 | self.conv1 = Conv2d(1, 16, (9, num_channel)) 15 | self.conv2 = Conv2d(16, 1, (1, num_channel)) 16 | self.conv3 = Conv2d(1, 1, (1, num_channel), padding=(0,0)) 17 | self.dropout = nn.Dropout(p=0.75) 18 | 19 | def forward(self, x): 20 | x = self.conv1(x) 21 | x = self.conv2(x) 22 | x = self.conv3(x) 23 | x = self.dropout(x) 24 | x = torch.flatten(x, start_dim=1) 25 | return x 26 | 27 | class ReferenceCNN(nn.Module): 28 | def __init__(self, num_channel=10, num_freq=40): 29 | super().__init__() 30 | self.conv1 = Conv2d(num_channel, num_freq, (9, 1)) 31 | self.conv2 = Conv2d(num_freq, 1, (9, 1)) 32 | self.dropout = nn.Dropout(p=0.15) 33 | 34 | def forward(self, x): 35 | x = self.conv1(x) 36 | x = self.conv2(x) 37 | x = self.dropout(x) 38 | x = torch.squeeze(x) 39 | return x 40 | 41 | class 
Corr(nn.Module): 42 | def __init__(self): 43 | super().__init__() 44 | 45 | def batch_dot(self, x_batch, y_batch): 46 | return torch.sum(x_batch * y_batch, axis=1) 47 | 48 | def forward(self, x, t): 49 | corr_xt = self.batch_dot(x.unsqueeze(-1),t) # [?,cl] 50 | corr_xx = self.batch_dot(x,x) # [?] 51 | corr_xx = corr_xx.unsqueeze(-1) # [?,1] 52 | corr_tt = torch.sum(t*t,axis=1) # [?,cl] 53 | self.corr = corr_xt/torch.sqrt(corr_tt)/torch.sqrt(corr_xx) 54 | return self.corr 55 | 56 | class ConvCA(nn.Module): 57 | def __init__(self, num_channel=10, num_classes=4, **kwargs): 58 | super().__init__() 59 | self.signal_cnn = SignalCNN(num_channel) 60 | self.reference_cnn = ReferenceCNN(num_channel, num_classes) 61 | self.correlation = Corr() 62 | self.dense = nn.Linear(in_features=num_classes, out_features=num_classes) 63 | 64 | def forward(self, x, ref): 65 | x = x.unsqueeze(-1) 66 | x = torch.transpose(x, 3, 1) 67 | 68 | x1 = self.signal_cnn(x) 69 | x2 = self.reference_cnn(ref) 70 | x = self.correlation(x1, x2) 71 | x = self.dense(x) 72 | return x 73 | -------------------------------------------------------------------------------- /splearn/nn/models/ConvCA/ConvCaLighting.py: -------------------------------------------------------------------------------- 1 | from splearn.nn.base import LightningModelClassifier 2 | from splearn.nn.models import ConvCA 3 | from splearn.nn.utils import get_backbone_and_fc 4 | from splearn.nn.loss import LabelSmoothCrossEntropyLoss 5 | 6 | 7 | class ConvCaLighting(LightningModelClassifier): 8 | def __init__( 9 | self, 10 | optimizer="adamw", 11 | scheduler="cosine_with_warmup", 12 | optimizer_learning_rate: float=1e-3, 13 | optimizer_epsilon: float=1e-6, 14 | optimizer_weight_decay: float=0.0005, 15 | scheduler_warmup_epochs: int=10, 16 | ): 17 | super().__init__() 18 | self.save_hyperparameters() 19 | 20 | self.criterion_classifier = LabelSmoothCrossEntropyLoss(smoothing=0.3) 21 | 22 | def build_model(self, model, **kwargs): 23 | self.model 
= model 24 | 25 | def forward(self, x, ref): 26 | y_hat = self.model(x, ref) 27 | return y_hat 28 | 29 | def train_val_step(self, batch, batch_idx): 30 | x1, x2, y = batch 31 | y_hat = self.forward(x1, x2) 32 | loss = self.criterion_classifier(y_hat, y.long()) 33 | return y_hat, y, loss 34 | 35 | def training_step(self, batch, batch_idx): 36 | y_hat, y, loss = self.train_val_step(batch, batch_idx) 37 | acc = self.train_acc(y_hat, y.long()) 38 | self.log('train_loss', loss, on_step=True) 39 | return loss 40 | 41 | def validation_step(self, batch, batch_idx): 42 | y_hat, y, loss = self.train_val_step(batch, batch_idx) 43 | acc = self.valid_acc(y_hat, y.long()) 44 | self.log('valid_loss', loss, on_step=True) 45 | return loss 46 | 47 | def test_step(self, batch, batch_idx): 48 | y_hat, y, loss = self.train_val_step(batch, batch_idx) 49 | acc = self.test_acc(y_hat, y.long()) 50 | self.log('test_loss', loss) 51 | return loss 52 | -------------------------------------------------------------------------------- /splearn/nn/models/ConvCA/__init__.py: -------------------------------------------------------------------------------- 1 | from .ConvCA import ConvCA 2 | from .ConvCaLighting import ConvCaLighting -------------------------------------------------------------------------------- /splearn/nn/models/DeepConvNet/__init__.py: -------------------------------------------------------------------------------- 1 | from .DeepConvNet import DeepConvNet -------------------------------------------------------------------------------- /splearn/nn/models/DeepConvNet/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn.functional import elu 6 | 7 | def identity(x): 8 | return x 9 | 10 | def transpose_time_to_spat(x): 11 | """Swap time and spatial dimensions. 
12 | Returns 13 | ------- 14 | x: torch.Tensor 15 | tensor in which last and first dimensions are swapped 16 | """ 17 | return x.permute(0, 3, 2, 1) 18 | 19 | def squeeze_final_output(x): 20 | """Removes empty dimension at end and potentially removes empty time 21 | dimension. It does not just use squeeze as we never want to remove 22 | first dimension. 23 | Returns 24 | ------- 25 | x: torch.Tensor 26 | squeezed tensor 27 | """ 28 | 29 | assert x.size()[3] == 1 30 | x = x[:, :, :, 0] 31 | if x.size()[2] == 1: 32 | x = x[:, :, 0] 33 | return x 34 | 35 | def np_to_th( 36 | X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs 37 | ): 38 | """ 39 | Convenience function to transform numpy array to `torch.Tensor`. 40 | Converts `X` to ndarray using asarray if necessary. 41 | Parameters 42 | ---------- 43 | X: ndarray or list or number 44 | Input arrays 45 | requires_grad: bool 46 | passed on to Variable constructor 47 | dtype: numpy dtype, optional 48 | var_kwargs: 49 | passed on to Variable constructor 50 | Returns 51 | ------- 52 | var: `torch.Tensor` 53 | """ 54 | if not hasattr(X, "__len__"): 55 | X = [X] 56 | X = np.asarray(X) 57 | if dtype is not None: 58 | X = X.astype(dtype) 59 | X_tensor = torch.tensor(X, requires_grad=requires_grad, **tensor_kwargs) 60 | if pin_memory: 61 | X_tensor = X_tensor.pin_memory() 62 | return X_tensor 63 | 64 | class Ensure4d(nn.Module): 65 | def forward(self, x): 66 | while(len(x.shape) < 4): 67 | x = x.unsqueeze(-1) 68 | return x 69 | 70 | class Expression(nn.Module): 71 | """Compute given expression on forward pass. 72 | Parameters 73 | ---------- 74 | expression_fn : callable 75 | Should accept variable number of objects of type 76 | `torch.autograd.Variable` to compute its output. 
77 | """ 78 | 79 | def __init__(self, expression_fn): 80 | super(Expression, self).__init__() 81 | self.expression_fn = expression_fn 82 | 83 | def forward(self, *x): 84 | return self.expression_fn(*x) 85 | 86 | def __repr__(self): 87 | if hasattr(self.expression_fn, "func") and hasattr( 88 | self.expression_fn, "kwargs" 89 | ): 90 | expression_str = "{:s} {:s}".format( 91 | self.expression_fn.func.__name__, str(self.expression_fn.kwargs) 92 | ) 93 | elif hasattr(self.expression_fn, "__name__"): 94 | expression_str = self.expression_fn.__name__ 95 | else: 96 | expression_str = repr(self.expression_fn) 97 | return ( 98 | self.__class__.__name__ + 99 | "(expression=%s) " % expression_str 100 | ) 101 | 102 | class AvgPool2dWithConv(nn.Module): 103 | """ 104 | Compute average pooling using a convolution, to have the dilation parameter. 105 | Parameters 106 | ---------- 107 | kernel_size: (int,int) 108 | Size of the pooling region. 109 | stride: (int,int) 110 | Stride of the pooling operation. 111 | dilation: int or (int,int) 112 | Dilation applied to the pooling filter. 113 | padding: int or (int,int) 114 | Padding applied before the pooling operation. 115 | """ 116 | 117 | def __init__(self, kernel_size, stride, dilation=1, padding=0): 118 | super(AvgPool2dWithConv, self).__init__() 119 | self.kernel_size = kernel_size 120 | self.stride = stride 121 | self.dilation = dilation 122 | self.padding = padding 123 | # don't name them "weights" to 124 | # make sure these are not accidentally used by some procedure 125 | # that initializes parameters or something 126 | self._pool_weights = None 127 | 128 | def forward(self, x): 129 | # Create weights for the convolution on demand: 130 | # size or type of x changed... 
131 | in_channels = x.size()[1] 132 | weight_shape = ( 133 | in_channels, 134 | 1, 135 | self.kernel_size[0], 136 | self.kernel_size[1], 137 | ) 138 | if self._pool_weights is None or ( 139 | (tuple(self._pool_weights.size()) != tuple(weight_shape)) or 140 | (self._pool_weights.is_cuda != x.is_cuda) or 141 | (self._pool_weights.data.type() != x.data.type()) 142 | ): 143 | n_pool = np.prod(self.kernel_size) 144 | weights = np_to_th( 145 | np.ones(weight_shape, dtype=np.float32) / float(n_pool) 146 | ) 147 | weights = weights.type_as(x) 148 | if x.is_cuda: 149 | weights = weights.cuda() 150 | self._pool_weights = weights 151 | 152 | pooled = F.conv2d( 153 | x, 154 | self._pool_weights, 155 | bias=None, 156 | stride=self.stride, 157 | dilation=self.dilation, 158 | padding=self.padding, 159 | groups=in_channels, 160 | ) 161 | return pooled -------------------------------------------------------------------------------- /splearn/nn/models/EEGNet/CompactEEGNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """EEGNet: Compact Convolutional Neural Network (Compact-CNN) https://arxiv.org/pdf/1803.04566.pdf 3 | """ 4 | import torch 5 | from torch import nn 6 | from splearn.nn.modules.conv2d import SeparableConv2d 7 | 8 | 9 | class CompactEEGNet(nn.Module): 10 | """ 11 | EEGNet: Compact Convolutional Neural Network (Compact-CNN) 12 | Compact Convolutional Neural Networks for Classification of Asynchronous Steady-state Visual Evoked Potentials 13 | https://arxiv.org/pdf/1803.04566.pdf 14 | """ 15 | def __init__(self, num_channel=10, num_classes=4, signal_length=1000, f1=96, f2=96, d=1): 16 | super().__init__() 17 | 18 | self.signal_length = signal_length 19 | 20 | # layer 1 21 | self.conv1 = nn.Conv2d(1, f1, (1, signal_length), padding=(0,signal_length//2)) 22 | self.bn1 = nn.BatchNorm2d(f1) 23 | self.depthwise_conv = nn.Conv2d(f1, d*f1, (num_channel, 1), groups=f1) 24 | self.bn2 = nn.BatchNorm2d(d*f1) 25 | 
self.avgpool1 = nn.AvgPool2d((1,4)) 26 | 27 | # layer 2 28 | self.separable_conv = SeparableConv2d( 29 | in_channels=f1, 30 | out_channels=f2, 31 | kernel_size=(1,16) 32 | ) 33 | self.bn3 = nn.BatchNorm2d(f2) 34 | self.avgpool2 = nn.AvgPool2d((1,8)) 35 | 36 | # layer 3 37 | self.fc = nn.Linear(in_features=f2*(signal_length//32), out_features=num_classes) 38 | 39 | self.dropout = nn.Dropout(p=0.5) 40 | self.elu = nn.ELU() 41 | 42 | def forward(self, x): 43 | 44 | # layer 1 45 | x = torch.unsqueeze(x,1) 46 | x = self.conv1(x) 47 | x = self.bn1(x) 48 | x = self.depthwise_conv(x) 49 | x = self.bn2(x) 50 | x = self.elu(x) 51 | x = self.avgpool1(x) 52 | x = self.dropout(x) 53 | 54 | # layer 2 55 | x = self.separable_conv(x) 56 | x = self.bn3(x) 57 | x = self.elu(x) 58 | x = self.avgpool2(x) 59 | x = self.dropout(x) 60 | 61 | # layer 3 62 | x = torch.flatten(x, start_dim=1) 63 | x = self.fc(x) 64 | 65 | return x -------------------------------------------------------------------------------- /splearn/nn/models/EEGNet/__init__.py: -------------------------------------------------------------------------------- 1 | from splearn.nn.models.EEGNet import CompactEEGNet 2 | -------------------------------------------------------------------------------- /splearn/nn/models/Multitask/Multitask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MultitaskSSVEP(nn.Module): 6 | """ 7 | Using multi-task learning to capture signals simultaneously from the fovea efficiently and the neighboring targets in the peripheral vision generate a visual response map. A calibration-free user-independent solution, desirable for clinical diagnostics. A stepping stone for an objective assessment of glaucoma patients’ visual field. 8 | Learn more about this model at https://jinglescode.github.io/ssvep-multi-task-learning/ 9 | This model is a multi-label model. 
Although it produces multiple outputs, we also used this model to get our multi-class results in our paper. 10 | 11 | Usage: 12 | model = MultitaskSSVEP( 13 | num_channel=11, 14 | num_classes=40, 15 | signal_length=250, 16 | ) 17 | x = torch.randn(2, 11, 250) 18 | print("Input shape:", x.shape) # torch.Size([2, 11, 250]) 19 | y = model(x) 20 | print("Output shape:", y.shape) # torch.Size([2, 40, 2]) 21 | 22 | Cite: 23 | @inproceedings{khok2020deep, 24 | title={Deep Multi-Task Learning for SSVEP Detection and Visual Response Mapping}, 25 | author={Khok, Hong Jing and Koh, Victor Teck Chang and Guan, Cuntai}, 26 | booktitle={2020 IEEE International Conference on Systems, Man, and Cybernetics (SMC)}, 27 | pages={1280--1285}, 28 | year={2020}, 29 | organization={IEEE} 30 | } 31 | """ 32 | 33 | def __init__(self, num_channel=10, num_classes=4, signal_length=1000, filters_n1=4, kernel_window_ssvep=59, kernel_window=19, conv_3_dilation=4, conv_4_dilation=4): 34 | super().__init__() 35 | 36 | filters = [filters_n1, filters_n1 * 2] 37 | 38 | self.conv_1 = Conv2dBlockELU(in_ch=1, out_ch=filters[0], kernel_size=(1, kernel_window_ssvep), w_in=signal_length) 39 | self.conv_2 = Conv2dBlockELU(in_ch=filters[0], out_ch=filters[0], kernel_size=(num_channel, 1)) 40 | self.conv_3 = Conv2dBlockELU(in_ch=filters[0], out_ch=filters[1], kernel_size=(1, kernel_window), padding=(0,conv_3_dilation-1), dilation=(1,conv_3_dilation), w_in=self.conv_1.w_out) 41 | self.conv_4 = Conv2dBlockELU(in_ch=filters[1], out_ch=filters[1], kernel_size=(1, kernel_window), padding=(0,conv_4_dilation-1), dilation=(1,conv_4_dilation), w_in=self.conv_3.w_out) 42 | self.conv_mtl = multitask_block(filters[1]*num_classes, num_classes, kernel_size=(1, self.conv_4.w_out)) 43 | 44 | self.dropout = nn.Dropout(p=0.5) 45 | 46 | def forward(self, x): 47 | x = torch.unsqueeze(x,1) 48 | 49 | x = self.conv_1(x) 50 | x = self.conv_2(x) 51 | x = self.dropout(x) 52 | 53 | x = self.conv_3(x) 54 | x = self.conv_4(x) 55 | x = 
self.dropout(x) 56 | 57 | x = self.conv_mtl(x) 58 | return x 59 | 60 | 61 | class Conv2dBlockELU(nn.Module): 62 | def __init__(self, in_ch, out_ch, kernel_size, padding=(0,0), dilation=(1,1), groups=1, w_in=None): 63 | super(Conv2dBlockELU, self).__init__() 64 | self.conv = nn.Sequential( 65 | nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, dilation=dilation, groups=groups), 66 | nn.BatchNorm2d(out_ch), 67 | nn.ELU(inplace=True) 68 | ) 69 | 70 | if w_in is not None: 71 | self.w_out = int( ((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1]-1)-1) / 1) + 1 ) 72 | 73 | def forward(self, x): 74 | return self.conv(x) 75 | 76 | 77 | class multitask_block(nn.Module): 78 | def __init__(self, in_ch, num_classes, kernel_size): 79 | super(multitask_block, self).__init__() 80 | self.num_classes = num_classes 81 | self.conv_mtl = nn.Conv2d(in_ch, num_classes*2, kernel_size=kernel_size, groups=num_classes) 82 | 83 | def forward(self, x): 84 | x = torch.cat(self.num_classes*[x], 1) 85 | x = self.conv_mtl(x) 86 | x = x.squeeze() 87 | x = x.view(-1, self.num_classes, 2) 88 | return x 89 | 90 | 91 | def test(): 92 | model = MultitaskSSVEP( 93 | num_channel=11, 94 | num_classes=40, 95 | signal_length=250, 96 | ) 97 | 98 | x = torch.randn(2, 11, 250) 99 | print("Input shape:", x.shape) # torch.Size([2, 11, 250]) 100 | y = model(x) 101 | print("Output shape:", y.shape) # torch.Size([2, 40, 2]) 102 | 103 | if __name__ == "__main__": 104 | test() -------------------------------------------------------------------------------- /splearn/nn/models/Multitask/MultitaskClassifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .Multitask import MultitaskSSVEP 4 | 5 | class MultitaskSSVEPClassifier(nn.Module): 6 | """ 7 | Using multi-task learning to capture signals simultaneously from the fovea efficiently and the neighboring targets in the peripheral vision generate a visual response 
map. A calibration-free user-independent solution, desirable for clinical diagnostics. A stepping stone for an objective assessment of glaucoma patients’ visual field. 8 | Learn more about this model at https://jinglescode.github.io/ssvep-multi-task-learning/ 9 | This model is a multi-class classifier. 10 | Usage: 11 | model = MultitaskSSVEPClassifier( 12 | num_channel=11, 13 | num_classes=40, 14 | signal_length=250, 15 | ) 16 | x = torch.randn(2, 11, 250) 17 | print("Input shape:", x.shape) # torch.Size([2, 11, 250]) 18 | y = model(x) 19 | print("Output shape:", y.shape) # torch.Size([2, 40]) 20 | 21 | Cite: 22 | @inproceedings{khok2020deep, 23 | title={Deep Multi-Task Learning for SSVEP Detection and Visual Response Mapping}, 24 | author={Khok, Hong Jing and Koh, Victor Teck Chang and Guan, Cuntai}, 25 | booktitle={2020 IEEE International Conference on Systems, Man, and Cybernetics (SMC)}, 26 | pages={1280--1285}, 27 | year={2020}, 28 | organization={IEEE} 29 | } 30 | """ 31 | 32 | def __init__(self, num_channel=10, num_classes=4, signal_length=1000, filters_n1=4, kernel_window_ssvep=59, kernel_window=19, conv_3_dilation=4, conv_4_dilation=4): 33 | super().__init__() 34 | self.base = MultitaskSSVEP(num_channel, num_classes, signal_length, filters_n1, kernel_window_ssvep, kernel_window, conv_3_dilation, conv_4_dilation) 35 | self.fc = nn.Linear(num_classes*2, out_features=num_classes) 36 | 37 | def forward(self, x): 38 | x = self.base(x) 39 | x = torch.flatten(x, start_dim=1) 40 | x = self.fc(x) 41 | return x 42 | 43 | 44 | def test(): 45 | model = MultitaskSSVEPClassifier( 46 | num_channel=11, 47 | num_classes=40, 48 | signal_length=250, 49 | ) 50 | 51 | x = torch.randn(2, 11, 250) 52 | print("Input shape:", x.shape) # torch.Size([2, 11, 250]) 53 | y = model(x) 54 | print("Output shape:", y.shape) # torch.Size([2, 40]) 55 | 56 | if __name__ == "__main__": 57 | test() -------------------------------------------------------------------------------- 
/splearn/nn/models/Multitask/__init__.py: -------------------------------------------------------------------------------- 1 | from splearn.nn.models.Multitask.Multitask import MultitaskSSVEP 2 | from splearn.nn.models.Multitask.MultitaskClassifier import MultitaskSSVEPClassifier 3 | -------------------------------------------------------------------------------- /splearn/nn/models/TimeDomainBasedCNN/TimeDomainBasedCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from splearn.nn.modules.conv2d import Conv2d 4 | 5 | 6 | class TimeDomainBasedCNN(nn.Module): 7 | """ 8 | Filter Bank Convolutional Neural Network for Short Time-Window Steady-State Visual Evoked Potential Classification 9 | https://ieeexplore.ieee.org/abstract/document/9632600 10 | """ 11 | def __init__(self, num_classes=4, signal_length=1000): 12 | super().__init__() 13 | 14 | self.conv1 = Conv2d(1, 16, (9, 1), padding=(0,0)) 15 | self.bn1 = nn.BatchNorm2d(16) 16 | 17 | self.conv2 = Conv2d(16, 16, (1, signal_length), padding="SAME", stride=5) 18 | self.bn2 = nn.BatchNorm2d(16) 19 | 20 | self.conv3 = Conv2d(16, 16, (1, 5), padding=(0,0)) 21 | self.bn3 = nn.BatchNorm2d(16) 22 | 23 | k2 = int(signal_length/5)-4 24 | self.conv4 = Conv2d(16, 32, (1, k2), padding=(0,0)) 25 | self.bn4 = nn.BatchNorm2d(32) 26 | 27 | self.fc = nn.Linear(32, out_features=num_classes) 28 | 29 | self.dropout = nn.Dropout(p=0.4) 30 | self.elu = nn.ELU() 31 | 32 | def forward(self, x): 33 | 34 | x = torch.unsqueeze(x,1) 35 | 36 | # the first convolution layer 37 | x = self.conv1(x) 38 | x = self.bn1(x) 39 | x = self.elu(x) 40 | 41 | # the second convolution layer 42 | x = self.dropout(x) 43 | x = self.conv2(x) 44 | x = self.bn2(x) 45 | x = self.elu(x) 46 | 47 | # the third convolution layer 48 | x = self.dropout(x) 49 | x = self.conv3(x) 50 | x = self.bn3(x) 51 | x = self.elu(x) 52 | 53 | # the fourth convolution layer 54 | x = self.dropout(x) 55 | x 
= self.conv4(x) 56 | x = self.bn4(x) 57 | x = self.elu(x) 58 | # flatten used to reduce the dimension of the features 59 | x = torch.flatten(x, start_dim=1) 60 | x = self.dropout(x) 61 | x = self.fc(x) 62 | return x -------------------------------------------------------------------------------- /splearn/nn/models/TimeDomainBasedCNN/__init__.py: -------------------------------------------------------------------------------- 1 | from splearn.nn.models.TimeDomainBasedCNN import TimeDomainBasedCNN 2 | -------------------------------------------------------------------------------- /splearn/nn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .EEGNet.CompactEEGNet import CompactEEGNet 2 | from .SimSiam.SimSiam import SimSiam 3 | from .SSLClassifier.SSLClassifier import SSLClassifier 4 | from .TimeDomainBasedCNN.TimeDomainBasedCNN import TimeDomainBasedCNN 5 | from .Multitask import MultitaskSSVEP 6 | from .Multitask import MultitaskSSVEPClassifier 7 | from .ConvCA import ConvCA 8 | from .ConvCA import ConvCaLighting 9 | from .DeepConvNet import DeepConvNet -------------------------------------------------------------------------------- /splearn/nn/modules/conv1d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Common 1D convolutions 3 | """ 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | from typing import Optional 9 | 10 | 11 | class DepthWiseConv1d(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, padding): 13 | super().__init__() 14 | self.padding = padding 15 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, groups=in_channels, bias=bias) 16 | 17 | def forward(self, x): 18 | x = F.pad(x, self.padding) 19 | return self.conv(x) 20 | 21 | #### 22 | 23 | 24 | class BaseConv1d(nn.Module): 25 | """ Base convolution module. 
""" 26 | def __init__(self): 27 | super(BaseConv1d, self).__init__() 28 | 29 | def _get_sequence_lengths(self, seq_lengths): 30 | return ( 31 | (seq_lengths + 2 * self.conv.padding[0] 32 | - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1) // self.conv.stride[0] + 1 33 | ) 34 | 35 | def forward(self, *args, **kwargs): 36 | raise NotImplementedError 37 | 38 | 39 | class PointwiseConv1d(BaseConv1d): 40 | r""" 41 | When kernel size == 1 conv1d, this operation is termed in literature as pointwise convolution. 42 | This operation often used to match dimensions. 43 | Args: 44 | in_channels (int): Number of channels in the input 45 | out_channels (int): Number of channels produced by the convolution 46 | stride (int, optional): Stride of the convolution. Default: 1 47 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 48 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 49 | Inputs: inputs 50 | - **inputs** (batch, in_channels, time): Tensor containing input vector 51 | Returns: outputs 52 | - **outputs** (batch, out_channels, time): Tensor produces by pointwise 1-D convolution. 53 | """ 54 | def __init__( 55 | self, 56 | in_channels: int, 57 | out_channels: int, 58 | stride: int = 1, 59 | padding: int = 0, 60 | bias: bool = True, 61 | ) -> None: 62 | super(PointwiseConv1d, self).__init__() 63 | self.conv = nn.Conv1d( 64 | in_channels=in_channels, 65 | out_channels=out_channels, 66 | kernel_size=1, 67 | stride=stride, 68 | padding=padding, 69 | bias=bias, 70 | ) 71 | 72 | def forward(self, inputs: Tensor) -> Tensor: 73 | return self.conv(inputs) 74 | 75 | 76 | class DepthwiseConv1d(BaseConv1d): 77 | r""" 78 | When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, 79 | this operation is termed in literature as depthwise convolution. 
80 | Args: 81 | in_channels (int): Number of channels in the input 82 | out_channels (int): Number of channels produced by the convolution 83 | kernel_size (int or tuple): Size of the convolving kernel 84 | stride (int, optional): Stride of the convolution. Default: 1 85 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 86 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 87 | Inputs: inputs 88 | - **inputs** (batch, in_channels, time): Tensor containing input vector 89 | Returns: outputs 90 | - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution. 91 | """ 92 | def __init__( 93 | self, 94 | in_channels: int, 95 | out_channels: int, 96 | kernel_size: int, 97 | stride: int = 1, 98 | padding: int = 0, 99 | bias: bool = False, 100 | ) -> None: 101 | super(DepthwiseConv1d, self).__init__() 102 | assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels" 103 | self.conv = nn.Conv1d( 104 | in_channels=in_channels, 105 | out_channels=out_channels, 106 | kernel_size=kernel_size, 107 | groups=in_channels, 108 | stride=stride, 109 | padding=padding, 110 | bias=bias, 111 | ) 112 | 113 | def forward(self, inputs: Tensor, input_lengths: Optional[Tensor] = None) -> Tensor: 114 | if input_lengths is None: 115 | return self.conv(inputs) 116 | else: 117 | return self.conv(inputs), self._get_sequence_lengths(input_lengths) -------------------------------------------------------------------------------- /splearn/nn/modules/functional.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | 4 | 5 | class GLU(nn.Module): 6 | r""" 7 | The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing 8 | in the paper “Language Modeling with Gated Convolutional Networks” 9 | """ 10 | def __init__(self, 
dim: int) -> None: 11 | super(GLU, self).__init__() 12 | self.dim = dim 13 | 14 | def forward(self, inputs: Tensor) -> Tensor: 15 | outputs, gate = inputs.chunk(2, dim=self.dim) 16 | return outputs * gate.sigmoid() 17 | 18 | 19 | class Swish(nn.Module): 20 | r""" 21 | Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied 22 | to a variety of challenging domains such as Image classification and Machine translation. 23 | """ 24 | 25 | def __init__(self): 26 | super(Swish, self).__init__() 27 | 28 | def forward(self, inputs: Tensor) -> Tensor: 29 | return inputs * inputs.sigmoid() 30 | -------------------------------------------------------------------------------- /splearn/nn/modules/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | r""" 9 | Positional Encoding proposed in "Attention Is All You Need". 10 | Since transformer contains no recurrence and no convolution, in order for the model to make 11 | use of the order of the sequence, we must add some positional information. 
12 | "Attention Is All You Need" use sine and cosine functions of different frequencies: 13 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model)) 14 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model)) 15 | """ 16 | def __init__(self, d_model: int = 512, max_len: int = 5000) -> None: 17 | super(PositionalEncoding, self).__init__() 18 | pe = torch.zeros(max_len, d_model, requires_grad=False) 19 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 20 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 21 | pe[:, 0::2] = torch.sin(position * div_term) 22 | pe[:, 1::2] = torch.cos(position * div_term) 23 | pe = pe.unsqueeze(0) 24 | self.register_buffer('pe', pe) 25 | 26 | def forward(self, length: int) -> Tensor: 27 | return self.pe[:, :length] -------------------------------------------------------------------------------- /splearn/nn/modules/relative_multi_head_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from typing import Optional 7 | 8 | from splearn.nn.modules.wrapper import Linear 9 | 10 | 11 | class RelativeMultiHeadAttention(nn.Module): 12 | r""" 13 | Multi-head attention with relative positional encoding. 14 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" 15 | Args: 16 | dim (int): The dimension of model 17 | num_heads (int): The number of attention heads. 
18 | dropout_p (float): probability of dropout 19 | Inputs: query, key, value, pos_embedding, mask 20 | - **query** (batch, time, dim): Tensor containing query vector 21 | - **key** (batch, time, dim): Tensor containing key vector 22 | - **value** (batch, time, dim): Tensor containing value vector 23 | - **pos_embedding** (batch, time, dim): Positional embedding tensor 24 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked 25 | Returns: 26 | - **outputs**: Tensor produces by relative multi head attention module. 27 | """ 28 | def __init__( 29 | self, 30 | dim: int = 512, 31 | num_heads: int = 16, 32 | dropout_p: float = 0.1, 33 | ) -> None: 34 | super(RelativeMultiHeadAttention, self).__init__() 35 | assert dim % num_heads == 0, "d_model % num_heads should be zero." 36 | 37 | self.dim = dim 38 | self.d_head = int(dim / num_heads) 39 | self.num_heads = num_heads 40 | self.sqrt_dim = math.sqrt(dim) 41 | 42 | self.query_proj = Linear(dim, dim) 43 | self.key_proj = Linear(dim, dim) 44 | self.value_proj = Linear(dim, dim) 45 | self.pos_proj = Linear(dim, dim, bias=False) 46 | 47 | self.dropout = nn.Dropout(p=dropout_p) 48 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 49 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 50 | torch.nn.init.xavier_uniform_(self.u_bias) 51 | torch.nn.init.xavier_uniform_(self.v_bias) 52 | 53 | self.out_proj = Linear(dim, dim) 54 | 55 | def forward( 56 | self, 57 | query: Tensor, 58 | key: Tensor, 59 | value: Tensor, 60 | pos_embedding: Tensor, 61 | mask: Optional[Tensor] = None, 62 | ) -> Tensor: 63 | batch_size = value.size(0) 64 | 65 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) 66 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 67 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 68 | pos_embedding = 
self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head) 69 | 70 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3)) 71 | pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1)) 72 | pos_score = self._relative_shift(pos_score) 73 | 74 | score = (content_score + pos_score) / self.sqrt_dim 75 | 76 | if mask is not None: 77 | mask = mask.unsqueeze(1) 78 | score.masked_fill_(mask, -1e4) 79 | 80 | attn = F.softmax(score, -1) 81 | attn = self.dropout(attn) 82 | 83 | context = torch.matmul(attn, value).transpose(1, 2) 84 | context = context.contiguous().view(batch_size, -1, self.dim) 85 | 86 | return self.out_proj(context) 87 | 88 | def _relative_shift(self, pos_score: Tensor) -> Tensor: 89 | batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() 90 | zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1) 91 | padded_pos_score = torch.cat([zeros, pos_score], dim=-1) 92 | 93 | padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) 94 | pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) 95 | 96 | return pos_score -------------------------------------------------------------------------------- /splearn/nn/modules/residual_connection_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from typing import Optional 4 | 5 | 6 | class ResidualConnectionModule(nn.Module): 7 | r""" 8 | Residual Connection Module. 
9 | outputs = (module(inputs) x module_factor + inputs x input_factor) 10 | """ 11 | def __init__( 12 | self, 13 | module: nn.Module, 14 | module_factor: float = 1.0, 15 | input_factor: float = 1.0, 16 | ) -> None: 17 | super(ResidualConnectionModule, self).__init__() 18 | self.module = module 19 | self.module_factor = module_factor 20 | self.input_factor = input_factor 21 | 22 | def forward(self, inputs: Tensor, mask: Optional[Tensor] = None) -> Tensor: 23 | if mask is None: 24 | return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor) 25 | else: 26 | return (self.module(inputs, mask) * self.module_factor) + (inputs * self.input_factor) -------------------------------------------------------------------------------- /splearn/nn/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import product 3 | 4 | from splearn.utils import Config 5 | 6 | 7 | def get_class_name(obj): 8 | return obj.__class__.__name__ 9 | 10 | def get_backbone_and_fc(backbone): 11 | backbone.output_dim = backbone.fc.in_features 12 | classifier = backbone.fc 13 | backbone.fc = torch.nn.Identity() 14 | return backbone, classifier 15 | 16 | 17 | class HyperParametersTuning(): 18 | ''' 19 | Example usage: 20 | >>> configs = { 21 | >>> 'num_layers': [8,16], 22 | >>> 'dim': [128,256], 23 | >>> 'dropout': [0.5], 24 | >>> } 25 | >>> 26 | >>> all_model_config = HyperParametersTuning(configs) 27 | >>> 28 | >>> for i in range(all_model_config.get_num_configs()): 29 | >>> print(all_model_config.get_config(i)) 30 | ''' 31 | def __init__(self, config): 32 | self.all_model_config = [dict(zip(configs, v)) for v in product(*configs.values())] 33 | 34 | def get_num_configs(self): 35 | return len(self.all_model_config) 36 | 37 | def get_config(self, i, return_config_object=True): 38 | if return_config_object: 39 | config = Config(self.all_model_config[i]) 40 | else: 41 | config = self.all_model_config[i] 42 | return 
# --------------------------------------------------------------------------------
# /splearn/utils/__init__.py:
# --------------------------------------------------------------------------------
from .config import Config
from .logger import Logger
# --------------------------------------------------------------------------------
# /splearn/utils/config.py:
# --------------------------------------------------------------------------------
from types import SimpleNamespace


class Config(SimpleNamespace):
    """Attribute-access wrapper around a (possibly nested) dictionary.

    ``Config({"a": 1, "b": {"c": 2}})`` gives ``cfg.a == 1`` and ``cfg.b.c == 2``:
    nested dicts are converted to nested ``Config`` objects. Reading an ordinary
    attribute that was never set returns ``None`` instead of raising.

    Args:
        dictionary: mapping of attribute names to values. Optional (treated as
            empty when omitted) because ``SimpleNamespace.__reduce__`` rebuilds
            instances by calling the class with no positional arguments, which
            ``copy``/``pickle`` rely on.
        **kwargs: extra attributes, set before ``dictionary`` is applied.
    """

    def __init__(self, dictionary=None, **kwargs):
        super().__init__(**kwargs)
        if dictionary is None:
            dictionary = {}
        for key, value in dictionary.items():
            if isinstance(value, dict):
                setattr(self, key, Config(value))
            else:
                setattr(self, key, value)

    def __getattr__(self, value):
        # Only called after normal lookup fails, so attributes that exist
        # resolve exactly as before and unknown names still yield None.
        # Missing special (dunder) names raise AttributeError instead: answering
        # None for them makes protocol probes such as
        # `hasattr(obj, "__setstate__")` (used by `copy`) succeed and then
        # try to call None.
        if value.startswith("__") and value.endswith("__"):
            raise AttributeError(value)
        return None
# --------------------------------------------------------------------------------
# /splearn/utils/logger.py:
# --------------------------------------------------------------------------------
import json
import os
from datetime import datetime
from pathlib import Path

class Logger():
    """Append plain-text lines to a timestamped log file.

    Args:
        log_dir: directory for log files; created (with parents) if missing.
        filename_postfix: optional suffix joined to the timestamp with "-".

    The target path ``<log_dir>/<YYYY_MM_DD-HH_MM_SS>[-<postfix>].txt`` (local
    time at construction) is stored in ``self.log_path``; the file itself is
    only created by the first ``write_to_log`` call.
    """
    def __init__(
        self,
        log_dir="run_logs",
        filename_postfix=None,
    ):
        # create dir if does not exist
        Path(log_dir).mkdir(parents=True, exist_ok=True)

        # get this log path
        date_time = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        filename = date_time+"-"+filename_postfix if filename_postfix is not None else date_time
        self.log_path = os.path.join(log_dir, filename+".txt")

    def write_to_log(self, content, break_line=False):
        """Append ``str(content)`` and a newline; prefix an extra newline when ``break_line`` is True."""
        tofile = str(content) + "\n"
        if break_line:
            tofile = "\n" + tofile

        # Explicit UTF-8 so non-ASCII content does not depend on the platform locale.
        with open(self.log_path, 'a', encoding="utf-8") as log_file:
            log_file.write(tofile)
# --------------------------------------------------------------------------------
/tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | 3 | We aim to bridge the gap for anyone who is new to signal processing. Check out these tutorials to get started. 4 | 5 | ### 1. Signal composition (time, sampling rate and frequency) 6 | 7 | In order to begin the signal processing adventure, we need to understand what we are dealing with. In the first tutorial, we will uncover what a signal is and what it is made up of. We will look at how the sampling rate and frequency can affect a signal. We will also see what happens when we combine multiple signals of different frequencies. 8 | 9 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Signal%20composition%20-%20time%2C%20sampling%20rate%20and%20frequency.ipynb) 10 | 11 | ### 2. Fourier Transform 12 | 13 | In the first tutorial, we learned that combining multiple signals will produce a new signal where all the frequencies are jumbled up. In this tutorial, we will learn about the Fourier Transform and how it can take a complex signal and decompose it into the frequencies that make it up. 14 | 15 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Fourier%20Transform.ipynb) 16 | 17 | ### 3. Denoising with mean-smooth filter 18 | 19 | We introduce the running mean filter, the simplest filter for denoising, and use it to remove noise that is normally distributed relative to the signal of interest. We will also learn what edge effects are.
20 | 21 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Denoising%20with%20mean-smooth%20filter.ipynb) 22 | 23 | ### 4. Denoising with Gaussian-smooth filter 24 | 25 | We will look at a slight adaptation of the mean-smooth filter, the Gaussian smoothing filter. It tends to produce smoother output than the mean-smooth filter. This does not mean that one is better than the other; it depends on the specific application. It is important to be aware of different filter types and how to use them. 26 | 27 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Denoising%20with%20Gaussian-smooth%20filter.ipynb) 28 | 29 | ### 5. Canonical correlation analysis 30 | 31 | Canonical correlation analysis (CCA) is applied to analyze the frequency components of a signal. In this tutorial, we use CCA for feature extraction and classification. 32 | 33 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Canonical%20Correlation%20Analysis.ipynb) 34 | 35 | ### 6. Task-related component analysis 36 | 37 | Task-related component analysis (TRCA) is a classification method originally developed for detecting steady-state visual evoked potentials (SSVEPs). 38 | 39 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jinglescode/python-signal-processing/blob/main/tutorials/Task-Related%20Component%20Analysis.ipynb) 40 | --------------------------------------------------------------------------------