├── .gitignore ├── README.md ├── Results.ipynb ├── assets ├── esp-both.png ├── esp-top.png └── final-all.png ├── eeg_lib ├── __init__.py ├── cca.py ├── filtering.py ├── freq_analysis.py ├── logging_server.py ├── plotting.py ├── styles.mplstyle ├── synthetic.py ├── trca.py └── utils.py ├── esp32-cmake.sh ├── experimentation ├── EEG-Analysis.ipynb ├── EEG-DAQ.ipynb └── Eigenvalue Algorithms.ipynb ├── micropython ├── README.md ├── boot.py ├── cross-compile.sh ├── lib │ ├── __init__.py │ ├── aws │ │ ├── aws_ca.pem │ │ ├── dab0ac2b5c-certificate.pem.crt │ │ ├── dab0ac2b5c-private.pem.key │ │ └── dab0ac2b5c-public.pem.key │ ├── computation.py │ ├── config.py │ ├── decoding.py │ ├── diagnostics.py │ ├── logging.py │ ├── networking.py │ ├── peripherals.py │ ├── requests.py │ ├── runner.py │ ├── scheduling.py │ ├── signal.py │ ├── synthetic.py │ ├── umqtt.py │ ├── utils.py │ ├── websocket │ │ ├── Multiserver │ │ │ └── ws_multiserver.py │ │ ├── __init__.py │ │ ├── test.html │ │ ├── websocket_demo.py │ │ ├── ws_connection.py │ │ └── ws_server.py │ └── websockets.py ├── main.py ├── mpy-esp32-algo-development.ipynb ├── mpy-esp32-networking.ipynb ├── mpy-esp32.ipynb └── mpy-modules │ └── .gitkeep ├── misc └── esp32-fft │ ├── esp32-fft.ino │ ├── fft.c │ ├── fft.h │ └── svd.c ├── requirements.txt └── ui └── ssvep_squares.html /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # micropyhon compiled bytecode 8 | *.mpy 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | *.npz 125 | /plots 126 | /data 127 | micropython/mpy-modules/*.mpy 128 | .DS_store 129 | eeg_lib/logs/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Development of an Ultra Low-Cost SSVEP-based BCI Device for Real-Time On-Device Decoding 3 | See the pre-print [here](https://www.biorxiv.org/content/10.1101/2022.01.29.478203v1). 4 | 5 | A poster presentation for original MSc project can be found [here](https://jamestev.github.io/msc-dissertation-poster.pdf) and the full MSc report [here](https://jamestev.github.io/msc-dissertation.pdf). 6 | 7 | ## Setup 8 | See [MicroPython setup](/micropython/README.md) for installing, building and flashing a port of MicroPython for the Espressif ESP32 used in this project. Instructions are also provided for development and experimentation. 9 | ## Background 10 | 11 | This project formed part of an MSc dissertation in collaboration with the [Next Generation Neural Interfaces (NGNI) Lab](https://www.imperial.ac.uk/next-generation-neural-interfaces) at Imperial College London. It was intended to explore the possibility of simultaneous decoding and visualisation of EEG signals acquired from ~100 different audience members simultaneously during a large scale exhibition. Signals acquired from the BCI devices are used for collaborative control in a multiplayer game (using only mental control). 12 | 13 | The cost of existing BCI technologies makes them inaccessible to the general public and prohibits their use on a mass scale. This project aimed to create a __novel, ultra low-cost BCI prototype__ that can change this in the hope to increase public engagement and facilite education in the field of neurotechnologies. 14 | 15 | ## Objectives 16 | The core focus of this study is to develop real time decoding and communication of raw EEG signals acquired from a proprietary EEG hardware device developed by the NGNI Lab. 17 | 18 | ### Constraints 19 | - very tight budget of ~ £20 per device 20 | - all processing related to sampling, signal processing, 21 | decoding and networking must be performed on-device 22 | - use the NGNI hardware prototype based on the Espressif 23 | ESP32 SoC (Tensilica Xtensa LX6 MCU) 24 | - real-time decoding and communication to an AWS cloud 25 | service 26 | - non-invasive BCI using only ‘dry’ surface electrodes 27 | 28 | ## Design 29 | 30 | ### SSVEPs for BCI control 31 | The core role of a BCI is to interpret intentions of a user by making sense of their brain signals. Steady state visual evoked potentials (SSVEPs) are modulations in the brain’s visual cortex in response to a visual stimulus which can be measured as sinusoids at the frequency of the visual stimulus being observed. Visual stimuli usually take the form of shapes that flicker at predetermined frequencies. 32 | 33 | A very simple SSVEP interface with flickering squares is provided in `ui/ssvep_squares.html`. Simply open it with your browser to try it out. You can use the url query parameters to adjust flicker frequencies. For example, by modifying the url in the browser to something like `?up=10&right=12` would set the upper square to flicker at 10Hz and the right at 12Hz. Note that these frequencies are approximate and depend largely on the browser you're using and the load on your machine. 34 | 35 | ### SSVEP decoding 36 | The EEG literature widely reports that multivariate statistical techniques - such as canonical correlation analysis (CCA) and its extensions - are optimal for SSVEP decoding. This project primarily explored two extensions: Multi-setCCA (MsetCCA) [1] and Generalised CCA (GCCA) [2]. An implementation of the seminal TRCA algorithm from [3] is also investigated. 37 | 38 | ## Performance 39 | Results showed that the MsetCCA was most effective and could achieve accuracy and ITR rates comparable with BCIs in the literature. Two promising parameter combinations: 40 | - 4 calibration trials of 0.75s each: accuracy of 95.56 ± 3.74% with ITR of 102 bits/min 41 | - 2 calibration trials of 1s each: accuracy of 80.56 ± 4.46% with ITR of 40 bits/min 42 | 43 | See the [full report](https://jamestev.github.io/msc-dissertation.pdf) for more detailed explanations, results and analysis. 44 | ## References 45 | [1] Y. Zhang, G. Zhou, J. Jin, X. Wang, and A. Cichocki, “Frequency recognition in ssvep-based bci using multiset canonical correlation analysis,” International journal of neural systems, vol. 24, no. 04, p. 1 450 013, 2014. 46 | 47 | [2] Q. Sun, M. Chen, L. Zhang, X. Yuan, and C. Li, “Improving ssvep identification accuracy via generalized canonical correlation analysis,” in 2021 10th International IEEE/EMBS Conference on Neural Engineering (NER), IEEE, 2021, pp. 61–64. 48 | 49 | [3] H. Tanaka, T. Katura, and H. Sato, “Task-related component analysis for functional neuroimaging and application to near-infrared spectroscopy data,” NeuroImage, vol. 64, pp. 308–327, 2013. 50 | -------------------------------------------------------------------------------- /assets/esp-both.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesTev/EEG-decoding/c4054dd9d1eac857aedd487a34f177c97d95c0af/assets/esp-both.png -------------------------------------------------------------------------------- /assets/esp-top.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesTev/EEG-decoding/c4054dd9d1eac857aedd487a34f177c97d95c0af/assets/esp-top.png -------------------------------------------------------------------------------- /assets/final-all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesTev/EEG-decoding/c4054dd9d1eac857aedd487a34f177c97d95c0af/assets/final-all.png -------------------------------------------------------------------------------- /eeg_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesTev/EEG-decoding/c4054dd9d1eac857aedd487a34f177c97d95c0af/eeg_lib/__init__.py -------------------------------------------------------------------------------- /eeg_lib/cca.py: -------------------------------------------------------------------------------- 1 | from sklearn.cross_decomposition import CCA as CCA_sklearn 2 | from .filtering import filterbank 3 | import numpy as np 4 | 5 | from scipy.stats import pearsonr 6 | from scipy.linalg import block_diag 7 | 8 | from .utils import resample, standardise, solve_gen_eig_prob 9 | 10 | np.random.seed(0) 11 | 12 | 13 | class GCCA: 14 | """ 15 | Generalised canonical component analysis. 16 | 17 | Ref: 'Improving SSVEP Identification Accuracy via Generalized Canonical Correlation Analysis' 18 | Sun, Chen et al 19 | """ 20 | 21 | def __init__(self, f_ssvep, fs, Nh=3, w=None, name=None): 22 | self.Nc, self.Ns, self.Nt = None, None, None 23 | self.Nh = Nh 24 | self.w_chi_bar_n = None 25 | self.w_Y_n = None 26 | self.w_Chi_n = None 27 | self.fs = fs 28 | self.f_ssvep = f_ssvep 29 | 30 | self.name = name or "gcca_{0}hz".format(f_ssvep) 31 | 32 | def fit(self, X): 33 | """ 34 | Fit against training tensor X. 35 | 36 | X should be a 3rd order tensor of dim (Nc x Ns x Nt) 37 | """ 38 | assert len(X.shape) == 3, "Expected 4th order input data tensor: Nc x Ns x Nt" 39 | self.Nc, self.Ns, self.Nt = X.shape 40 | 41 | Chi_n = X 42 | Chi_n_c = Chi_n.reshape((self.Nc, self.Ns * self.Nt)) 43 | 44 | Chi_bar_n = np.mean( 45 | Chi_n, axis=-1 46 | ) # mean over trials for each channel with all samples: output shape is Nc x Ns x 1 47 | Chi_bar_n_c = np.concatenate( 48 | [Chi_bar_n for i in range(self.Nt)], axis=1 49 | ) # concat along columns 50 | 51 | Y_n = cca_reference([self.f_ssvep], self.fs, self.Ns, Nh=self.Nh).reshape( 52 | -1, self.Ns 53 | ) 54 | Y_n_c = np.concatenate([Y_n for i in range(self.Nt)], axis=1) 55 | 56 | # form X and D and find eigenvals 57 | X = np.c_[Chi_n_c.T, Chi_bar_n_c.T, Y_n_c.T].T 58 | 59 | d1 = Chi_n_c.dot(Chi_n_c.T) 60 | d2 = Chi_bar_n_c.dot(Chi_bar_n_c.T) 61 | d3 = Y_n_c.dot(Y_n_c.T) 62 | D = block_diag(d1, d2, d3) 63 | 64 | lam, W_eig = solve_gen_eig_prob( 65 | X.dot(X.T), D 66 | ) # solve generalised eigenvalue problem 67 | 68 | i = np.argmax(np.real(lam)) 69 | w = W_eig[:, i] # optimal spatial filter vector with dim (2*Nc + 2*Nh) 70 | 71 | w_Chi_n = w[: self.Nc] # first Nc weight values correspond to data channels 72 | w_Chi_bar_n = w[ 73 | self.Nc : 2 * self.Nc 74 | ] # second Nc weights correspond to Nc template channels 75 | w_Y_n = w[ 76 | 2 * self.Nc : 77 | ] # final 2*Nh weights correspond to ref sinusoids with harmonics 78 | 79 | self.w_chi_bar_n = w_Chi_bar_n.T.dot(Chi_bar_n) 80 | self.w_Y_n = w_Y_n.T.dot(Y_n) 81 | self.w_Chi_n = w_Chi_n 82 | 83 | def classify(self, X_test): 84 | if self.w_chi_bar_n is None: 85 | raise ValueError("call `.fit(X_train)` before performing classification.") 86 | 87 | rho1 = pearsonr(self.w_Chi_n.T.dot(X_test), self.w_chi_bar_n)[0] 88 | rho2 = pearsonr(self.w_Chi_n.T.dot(X_test), self.w_Y_n)[0] 89 | 90 | rho = np.sum([np.sign(rho_i) * rho_i ** 2 for rho_i in [rho1, rho2]]) 91 | 92 | return rho 93 | 94 | 95 | class GCCA_SSVEP: 96 | """ 97 | Generalised canonical component analysis. 98 | 99 | Ref: 'Improving SSVEP Identification Accuracy via Generalized Canonical Correlation Analysis' 100 | Sun, Chen et al 101 | """ 102 | 103 | def __init__(self, stim_freqs, fs, Nh=3, W=None): 104 | self.Nf, self.Nc, self.Ns, self.Nt = None, None, None, None 105 | self.Nh = Nh 106 | self.W = W 107 | self.stim_freqs = stim_freqs 108 | self.fs = fs 109 | 110 | def fit(self, X): 111 | """ 112 | Fit against training tensor X. 113 | 114 | X should be a 4th order tensor of dim (Nf x Nc x Ns x Nt) 115 | """ 116 | assert ( 117 | len(X.shape) == 4 118 | ), "Expected 4th order input data tensor: Nf x Nc x Ns x Nt" 119 | self.Chi = X 120 | self.Nf, self.Nc, self.Ns, self.Nt = X.shape 121 | 122 | W = [] 123 | self.Chi_bar = [] 124 | self.Y = [] 125 | for n in range(len(self.stim_freqs)): 126 | Chi_n = self.Chi[n, :, :, :] 127 | Chi_n_c = Chi_n.reshape((self.Nc, self.Ns * self.Nt)) 128 | 129 | Chi_bar_n = np.mean( 130 | Chi_n, axis=-1 131 | ) # mean over trials for each channel with all samples: output shape is Nc x Ns x 1 132 | self.Chi_bar.append(Chi_bar_n) 133 | Chi_bar_n_c = np.concatenate( 134 | [Chi_bar_n for i in range(self.Nt)], axis=1 135 | ) # concat along columns 136 | 137 | Y_n = cca_reference( 138 | [self.stim_freqs[n]], self.fs, self.Ns, Nh=self.Nh 139 | ).reshape(-1, self.Ns) 140 | self.Y.append(Y_n) 141 | Y_n_c = np.concatenate([Y_n for i in range(self.Nt)], axis=1) 142 | 143 | # form X and D and find eigenvals 144 | X = np.c_[Chi_n_c.T, Chi_bar_n_c.T, Y_n_c.T].T 145 | 146 | d1 = Chi_n_c.dot(Chi_n_c.T) 147 | d2 = Chi_bar_n_c.dot(Chi_bar_n_c.T) 148 | d3 = Y_n_c.dot(Y_n_c.T) 149 | D = block_diag(d1, d2, d3) 150 | 151 | lam, W_eig = solve_gen_eig_prob( 152 | X.dot(X.T), D 153 | ) # solve generalised eigenvalue problem 154 | 155 | i = np.argmax(np.real(lam)) 156 | W.append( 157 | W_eig[:, i] 158 | ) # optimal spatial filter vector with dim (2*Nc + 2*Nh) 159 | 160 | self.Chi_bar = np.array(self.Chi_bar) # form tensors 161 | self.Y = np.array(self.Y) 162 | self.W = np.array(W) 163 | 164 | def classify(self, X_test): 165 | if self.W is None: 166 | raise ValueError( 167 | "w must be computed using `compute_w` before performing classification." 168 | ) 169 | result = {f: 0 for f in self.stim_freqs} 170 | 171 | for i in range(len(self.stim_freqs)): 172 | Chi_bar_n = self.Chi_bar[i, :, :] 173 | Y_n = self.Y[i, :, :] 174 | 175 | w = self.W[i, :] 176 | w_Chi_n = w[: self.Nc] # first Nc weight values correspond to data channels 177 | w_Chi_bar_n = w[ 178 | self.Nc : 2 * self.Nc 179 | ] # second Nc weights correspond to Nc template channels 180 | w_Y_n = w[ 181 | 2 * self.Nc : 182 | ] # final 2*Nh weights correspond to ref sinusoids with harmonics 183 | 184 | rho1 = pearsonr(w_Chi_n.T.dot(X_test), w_Chi_bar_n.T.dot(Chi_bar_n))[0] 185 | rho2 = pearsonr(w_Chi_n.T.dot(X_test), w_Y_n.T.dot(Y_n))[0] 186 | 187 | rho_n = sum([np.sign(rho_i) * rho_i ** 2 for rho_i in [rho1, rho2]]) 188 | result[self.stim_freqs[i]] = rho_n 189 | return result 190 | 191 | 192 | class MsetCCA: 193 | """ 194 | Multi set CCA 195 | 196 | Ref: FREQUENCY RECOGNITION IN SSVEP-BASED BCI USING MULTISET CANONICAL CORRELATION ANALYSIS, Zhang, Zhou et al 197 | """ 198 | 199 | def __init__(self): 200 | self.Nc, self.Ns, self.Nt = None, None, None 201 | 202 | def fit(self, chi): 203 | """ 204 | Fit against training tensor chi. 205 | 206 | chi should be a 3rd order tensor of dim (Nc x Ns x Nt) 207 | """ 208 | assert ( 209 | len(chi.shape) == 3 210 | ), "Expected 3rd order input data tensor for freq. fm: Nc x Ns x Nt" 211 | 212 | Nc, Ns, Nt = chi.shape 213 | 214 | chi_c = np.vstack([chi[:, :, i] for i in range(Nt)]) 215 | R = chi_c.dot(chi_c.T) 216 | 217 | # form inra-trial covariance matrix S 218 | blocks = [chi[:, :, i].dot(chi[:, :, i].T) for i in range(Nt)] 219 | S = block_diag(*blocks) 220 | 221 | lam, V = solve_gen_eig_prob((R - S), S) # solve generalise eig value problem 222 | w = V[:, np.argmax(lam)].reshape( 223 | (Nt, Nc) 224 | ) # sort by largest eig vals in lam vector. TODO: check reshaping 225 | 226 | self.Y = np.array( 227 | [w[i, :].T.dot(chi[:, :, i]) for i in range(Nt)] 228 | ) # form optimised reference matrix 229 | self.Nc, self.Ns, self.Nt = Nc, Ns, Nt 230 | 231 | def compute_corr(self, X_test, method="cca"): 232 | if self.Y is None: 233 | raise ValueError( 234 | "Reference matrix Y must be computed using `fit` before computing corr" 235 | ) 236 | if method == "eig": 237 | rho = CCA.cca_eig(X_test.T, self.Y.T)[0] 238 | else: # use sklearn implementation 239 | cca = CCA_sklearn(n_components=1) 240 | Xc, Yc = cca.fit_transform(X_test.T, self.Y.T) 241 | rho = pearsonr(Xc[:, 0], Yc[:, 0])[0] 242 | return rho 243 | 244 | 245 | class MsetCCA_SSVEP: 246 | def __init__(self, stim_freqs): 247 | self.stim_freqs = stim_freqs 248 | self.models = { 249 | f: MsetCCA() for f in stim_freqs 250 | } # init independent TRCA models per stim freq 251 | 252 | def fit(self, X_ssvep): 253 | """ 254 | Fit the independent Nf TRCA models using input data tensor `X_ssvep` 255 | 256 | :param 257 | X_ssvep: 4th order data tensor (Nf x Nc x Ns x Nt) 258 | """ 259 | assert ( 260 | len(X_ssvep.shape) == 4 261 | ), "Expected a 4th order data tensor with shape (Nf x Nc x Ns x Nt)" 262 | assert ( 263 | len(self.stim_freqs) == X_ssvep.shape[0] 264 | ), "Length of supplied stim freqs does not match first dimension of input data" 265 | 266 | for i, f in enumerate(self.stim_freqs): 267 | self.models[f].fit(X_ssvep[i, :, :, :]) 268 | 269 | def classify(self, X_test, method="cca"): 270 | assert len(X_test.shape) == 2, "Expected a matrix with shape (Nc x Ns)" 271 | 272 | return { 273 | f: self.models[f].compute_corr(X_test, method=method) 274 | for f in self.stim_freqs 275 | } 276 | 277 | def get_eig(self): 278 | return {f: self.models[f].get_eig() for f in self.stim_freqs} 279 | 280 | 281 | class CCA: 282 | def __init__(self, stim_freqs, fs, Nh=3): 283 | self.Nh = Nh 284 | self.stim_freqs = stim_freqs 285 | self.fs = fs 286 | self.cca_models = {f: CCA_sklearn(n_components=1) for f in stim_freqs} 287 | self.is_fit = False 288 | 289 | def fit(self, X, resampling_factor=None): 290 | for f in self.stim_freqs: 291 | Y = cca_reference([f], self.fs, len(X), Nh=self.Nh, standardise_out=True) 292 | 293 | if resampling_factor is not None: 294 | X = resample(X, resampling_factor) 295 | 296 | self.cca_models[f].fit(X, Y) 297 | 298 | self.is_fit = True 299 | 300 | def classify(self, X_test, method="eig"): 301 | if not self.is_fit and method != "eig": 302 | self.fit(X_test) 303 | 304 | result = {} 305 | Cxx = np.dot( 306 | X_test, X_test.transpose() 307 | ) # precompute data auto correlation matrix 308 | for f in self.stim_freqs: 309 | Y = cca_reference( 310 | [f], self.fs, len(X_test), Nh=self.Nh, standardise_out=True 311 | ) 312 | if method == "eig": 313 | rho = self.cca_eig(X_test, Y)[0] 314 | else: 315 | Xc, Yc = self.cca_models[f].transform( 316 | X_test, Y 317 | ) # canonical variable matrices. Xc = X^T.W_x 318 | rho = pearsonr(Xc[:, 0], Yc[:, 0])[0] 319 | result[f] = rho 320 | return result 321 | 322 | @staticmethod 323 | def cca_eig(X, Y, n_components=1): 324 | Cxx = X.T.dot(X) # auto correlation matrix 325 | Cyy = Y.T.dot(Y) 326 | Cxy = X.T.dot(Y) # cross correlation matrix 327 | Cyx = Y.T.dot(X) # same as Cxy.T 328 | 329 | M1 = np.linalg.inv(Cxx).dot(Cxy) # intermediate result 330 | M2 = np.linalg.inv(Cyy).dot(Cyx) 331 | 332 | M = M1.dot(M2) 333 | lam = np.linalg.eigvals(M) 334 | return sorted(np.sqrt(lam), reverse=True)[ 335 | :n_components 336 | ] # return largest n sqrt eig vals 337 | 338 | 339 | def fbcca(eeg, list_freqs, fs, num_harms=3, num_fbs=5): 340 | 341 | """ 342 | Steady-state visual evoked potentials (SSVEPs) detection using the filter 343 | bank canonical correlation analysis (FBCCA)-based method [1]. 344 | function results = test_fbcca(eeg, list_freqs, fs, num_harms, num_fbs) 345 | Input: 346 | eeg : Input eeg data 347 | (# of targets, # of channels, Data length [sample]) 348 | list_freqs : List for stimulus frequencies 349 | fs : Sampling frequency 350 | num_harms : # of harmonics 351 | num_fbs : # of filters in filterbank analysis 352 | Output: 353 | results : The target estimated by this method 354 | Reference: 355 | [1] X. Chen, Y. Wang, S. Gao, T. -P. Jung and X. Gao, 356 | "Filter bank canonical correlation analysis for implementing a 357 | high-speed SSVEP-based brain-computer interface", 358 | J. Neural Eng., vol.12, 046008, 2015. 359 | """ 360 | 361 | fb_coefs = np.power(np.arange(1, num_fbs + 1), (-1.25)) + 0.25 362 | 363 | ( 364 | num_targs, 365 | _, 366 | num_smpls, 367 | ) = eeg.shape # 40 taget (means 40 fre-phase combination that we want to predict) 368 | y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms) 369 | cca = CCA(n_components=1) # initilize CCA 370 | 371 | # result matrix 372 | r = np.zeros((num_fbs, num_targs)) 373 | results = np.zeros(num_targs) 374 | 375 | for targ_i in range(num_targs): 376 | test_tmp = np.squeeze(eeg[targ_i, :, :]) # deal with one target a time 377 | for fb_i in range( 378 | num_fbs 379 | ): # filter bank number, deal with different filter bank 380 | testdata = filterbank(test_tmp, fs, fb_i) # data after filtering 381 | for class_i in range(num_targs): 382 | refdata = np.squeeze( 383 | y_ref[class_i, :, :] 384 | ) # pick corresponding freq target reference signal 385 | test_C, ref_C = cca.fit_transform(testdata.T, refdata.T) 386 | # len(row) = len(observation), len(column) = variables of each observation 387 | # number of rows should be the same, so need transpose here 388 | # output is the highest correlation linear combination of two sets 389 | r_tmp, _ = pearsonr( 390 | np.squeeze(test_C), np.squeeze(ref_C) 391 | ) # return r and p_value, use np.squeeze to adapt the API 392 | r[fb_i, class_i] = r_tmp 393 | 394 | rho = np.dot( 395 | fb_coefs, r 396 | ) # weighted sum of r from all different filter banks' result 397 | tau = np.argmax( 398 | rho 399 | ) # get maximum from the target as the final predict (get the index) 400 | results[targ_i] = tau # index indicate the maximum(most possible) target 401 | return results 402 | 403 | 404 | def cca_reference(list_freqs, fs, Ns, Nh=3, standardise_out=False): 405 | 406 | """ 407 | Generate reference signals for the canonical correlation analysis (CCA) 408 | -based steady-state visual evoked potentials (SSVEPs) detection [1, 2]. 409 | function [ y_ref ] = cca_reference(listFreq, fs, Ns, Nh) 410 | Input: 411 | list_freqs : stimulus frequencies 412 | fs : Sampling frequency 413 | Ns : # of samples in trial 414 | Nh : # of harmonics 415 | Output: 416 | y_ref : Generated reference signals with shape (Nf, Ns, 2*Nh) 417 | """ 418 | 419 | num_freqs = len(list_freqs) 420 | tidx = np.arange(1, Ns + 1) / fs # time index 421 | 422 | y_ref = np.zeros((num_freqs, 2 * Nh, Ns)) 423 | for freq_i in range(num_freqs): 424 | tmp = [] 425 | for harm_i in range(1, Nh + 1): 426 | stim_freq = list_freqs[freq_i] # in HZ 427 | # Sin and Cos 428 | tmp.extend( 429 | [ 430 | np.sin(2 * np.pi * tidx * harm_i * stim_freq), 431 | np.cos(2 * np.pi * tidx * harm_i * stim_freq), 432 | ] 433 | ) 434 | y_ref[freq_i] = tmp # 2*num_harms because include both sin and cos 435 | 436 | y_ref = np.squeeze(y_ref) 437 | if standardise_out: # zero mean, unit std. dev 438 | return standardise(y_ref) 439 | return y_ref 440 | 441 | 442 | """ 443 | Base on fbcca, but adapt to our input format 444 | """ 445 | 446 | 447 | def fbcca_realtime(data, list_freqs, fs, num_harms=3, num_fbs=5): 448 | 449 | fb_coefs = np.power(np.arange(1, num_fbs + 1), (-1.25)) + 0.25 450 | 451 | num_targs = len(list_freqs) 452 | _, num_smpls = data.shape 453 | 454 | y_ref = cca_reference(list_freqs, fs, num_smpls, num_harms) 455 | cca = CCA(n_components=1) # initialize CCA 456 | 457 | # result matrix 458 | r = np.zeros((num_fbs, num_targs)) 459 | 460 | for fb_i in range(num_fbs): # filter bank number, deal with different filter bank 461 | testdata = filterbank(data, fs, fb_i) # data after filtering 462 | for class_i in range(num_targs): 463 | refdata = np.squeeze( 464 | y_ref[class_i, :, :] 465 | ) # pick corresponding freq target reference signal 466 | test_C, ref_C = cca.fit_transform(testdata.T, refdata.T) 467 | r_tmp, _ = pearsonr( 468 | np.squeeze(test_C), np.squeeze(ref_C) 469 | ) # return r and p_value 470 | if r_tmp == np.nan: 471 | r_tmp = 0 472 | r[fb_i, class_i] = r_tmp 473 | 474 | rho = np.dot( 475 | fb_coefs, r 476 | ) # weighted sum of r from all different filter banks' result 477 | print(rho) # print out the correlation 478 | result = np.argmax( 479 | rho 480 | ) # get maximum from the target as the final predict (get the index), and index indicates the maximum entry(most possible target) 481 | """ Threshold """ 482 | THRESHOLD = 2.1 483 | if ( 484 | abs(rho[result]) < THRESHOLD 485 | ): # 2.587=np.sum(fb_coefs*0.8) #2.91=np.sum(fb_coefs*0.9) #1.941=np.sum(fb_coefs*0.6) 486 | return 999 # if the correlation isn't big enough, do not return any command 487 | else: 488 | return result 489 | -------------------------------------------------------------------------------- /eeg_lib/filtering.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import filtfilt, iirnotch, freqz, butter, iirfilter 2 | import numpy as np 3 | import warnings 4 | import scipy.signal 5 | 6 | 7 | # Filter requirements. 8 | 9 | 10 | def butterworth_lowpass(x, fc, fs, order): 11 | nyq = 0.5 * fs # Nyquist Frequency 12 | normal_cutoff = fc / nyq 13 | # Get the filter coefficients 14 | b, a = butter(order, normal_cutoff, btype="low", analog=False) 15 | y = filtfilt(b, a, x) 16 | return y 17 | 18 | 19 | def iir_bandpass(x, f0, f1, fs, order=4, ftype="cheby2"): 20 | b, a = iirfilter( 21 | order, [f0, f1], rs=60, btype="band", fs=fs, ftype=ftype, output="ba" 22 | ) 23 | y = filtfilt(b, a, x) 24 | return y 25 | 26 | 27 | def iir_notch(y, f0, fs, Q=30): 28 | w0 = f0 / (fs / 2) 29 | Q = 30 30 | b, a = iirnotch(w0, Q) 31 | 32 | # filter response 33 | w, h = freqz(b, a) 34 | filt_freq = w * fs / (2 * np.pi) 35 | y_filt = filtfilt(b, a, y) 36 | 37 | return y_filt 38 | 39 | 40 | def butter_notch(y, f0, fs, order=2): 41 | w0 = [(f0 - 15) / (fs / 2), (f0 + 15) / (fs / 2)] 42 | b, a = butter(order, w0, btype="bandstop") 43 | w, h = freqz(b, a) 44 | filt_freq = w * fs / (2 * np.pi) 45 | y_filt = filtfilt(b, a, y) 46 | 47 | return y_filt 48 | 49 | 50 | def iir_bandpass_ssvep_tensor(X, f0, f1, fs, order=4, ftype="cheby2"): 51 | """ 52 | Perform IIR bandpass over SSVEP frequency band. 53 | 54 | Expects 4th order data tensor X in form: Nf x Nc X Ns x Nt 55 | """ 56 | X_f = np.zeros_like(X) 57 | Nf, Nc, Ns, Nt = X.shape 58 | 59 | bp_filt = lambda x: iir_bandpass(x, f0, f1, fs, order=order, ftype=ftype) 60 | 61 | for trial in range(Nt): # TODO parallelise or similar 62 | for f_idx in range(Nf): 63 | X_i = X[f_idx, :, :, trial] 64 | X_f[f_idx, :, :, trial] = np.array( 65 | [bp_filt(X_i[chan, :]) for chan in range(Nc)] 66 | ) 67 | return X_f 68 | 69 | 70 | def filterbank(eeg, fs, idx_fb): 71 | if idx_fb == None: 72 | warnings.warn( 73 | "stats:filterbank:MissingInput " 74 | + "Missing filter index. Default value (idx_fb = 0) will be used." 75 | ) 76 | idx_fb = 0 77 | elif idx_fb < 0 or 9 < idx_fb: 78 | raise ValueError( 79 | "stats:filterbank:InvalidInput " 80 | + "The number of sub-bands must be 0 <= idx_fb <= 9." 81 | ) 82 | 83 | if len(eeg.shape) == 2: 84 | num_chans = eeg.shape[0] 85 | num_trials = 1 86 | else: 87 | num_chans, _, num_trials = eeg.shape 88 | 89 | # Nyquist Frequency = Fs/2N 90 | Nq = fs / 2 91 | 92 | passband = [6, 14, 22, 30, 38, 46, 54, 62, 70, 78] 93 | stopband = [4, 10, 16, 24, 32, 40, 48, 56, 64, 72] 94 | Wp = [passband[idx_fb] / Nq, 90 / Nq] 95 | Ws = [stopband[idx_fb] / Nq, 100 / Nq] 96 | [N, Wn] = scipy.signal.cheb1ord( 97 | Wp, Ws, 3, 40 98 | ) # band pass filter StopBand=[Ws(1)~Ws(2)] PassBand=[Wp(1)~Wp(2)] 99 | [B, A] = scipy.signal.cheby1(N, 0.5, Wn, "bandpass") # Wn passband edge frequency 100 | 101 | y = np.zeros(eeg.shape) 102 | if num_trials == 1: 103 | for ch_i in range(num_chans): 104 | # apply filter, zero phass filtering by applying a linear filter twice, once forward and once backwards. 105 | # to match matlab result we need to change padding length 106 | y[ch_i, :] = scipy.signal.filtfilt( 107 | B, A, eeg[ch_i, :], padtype="odd", padlen=3 * (max(len(B), len(A)) - 1) 108 | ) 109 | 110 | else: 111 | for trial_i in range(num_trials): 112 | for ch_i in range(num_chans): 113 | y[ch_i, :, trial_i] = scipy.signal.filtfilt( 114 | B, 115 | A, 116 | eeg[ch_i, :, trial_i], 117 | padtype="odd", 118 | padlen=3 * (max(len(B), len(A)) - 1), 119 | ) 120 | return y 121 | -------------------------------------------------------------------------------- /eeg_lib/freq_analysis.py: -------------------------------------------------------------------------------- 1 | # compute STFT using Welch's method 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from scipy import signal 5 | from scipy.fft import fftshift 6 | 7 | import pandas as pd 8 | from .utils import dB 9 | 10 | 11 | def plot_stft_spectra( 12 | Sxx, 13 | f, 14 | ssvep_f0=None, 15 | figsize=(14, 12), 16 | recursive_av=True, 17 | f_ssvep=[8.75, 10, 12, 15], 18 | ): 19 | def plot_spectrum(ax, f, Sxx, is_db=False, ssvep_f0=None, title=""): 20 | if not is_db: 21 | Sxx = dB(np.abs(Sxx) ** 2) 22 | ax.plot(f, Sxx) 23 | if ssvep_f0 is not None: 24 | ax.axvline(ssvep_f0, c="r", ls=":", lw=1.5, label="expected SSVEP $f_0$") 25 | ax.legend() 26 | ax.set_xlabel("Frequency (Hz)") 27 | ax.set_ylabel("PSD (dB)") 28 | ax.set_title(title) 29 | return ax 30 | 31 | n_win = Sxx.shape[1] 32 | fig, axes = plt.subplots( 33 | n_win, 34 | 3, 35 | figsize=figsize, 36 | ) 37 | for i in range(n_win): 38 | Sxx_av = np.mean(Sxx[:, 0 : i + 1], axis=1) 39 | ax_psd = plot_spectrum( 40 | axes[i][0], f, Sxx[:, i], ssvep_f0=ssvep_f0, title=f"Window {i+1}" 41 | ) 42 | ax_zoom = plot_spectrum( 43 | axes[i][1], f, Sxx[:, i], ssvep_f0=ssvep_f0, title=f"Window {i+1}" 44 | ) 45 | 46 | if recursive_av: 47 | plot_spectrum(ax_psd, f, Sxx_av) 48 | plot_spectrum(ax_zoom, f, Sxx_av) 49 | 50 | ax_zoom.set_xlim( 51 | min(f_ssvep) * 0.75, max(f_ssvep) * 1.25 52 | ) # zoom in on freq band of interest 53 | ax_psd.set_xlim(0, 70) 54 | width = 0.2 55 | # first, decode using indepdendent spectra 56 | f_decoded, p = decode_ssvep(Sxx[:, i], f, f_ssvep) 57 | axes[i][2].bar(x=f_ssvep - width / 2, height=p, width=width) 58 | 59 | # now, decode using recursively-averaged spectra 60 | f_decoded, p = decode_ssvep(Sxx_av, f, f_ssvep) 61 | # labels=[f'$f_{i}$' for i in range(len(p))] 62 | axes[i][2].bar(x=f_ssvep + width / 2, height=p, tick_label=f_ssvep, width=width) 63 | 64 | axes[i][2].legend(["P($f_i$) no av.", "P($f_i$) with av."]) 65 | 66 | fig.tight_layout(pad=1.5) 67 | 68 | 69 | def decode_ssvep(Sxx, f, target_freqs, convert_to_mag=True, verbose=False): 70 | """ 71 | Takes PSD estimate Sxx and freq vector f and finds candidate freq in target_freqs 72 | that corresponds to max power in Sxx. Also returns list of Sxx values at target_freqs. 73 | """ 74 | if convert_to_mag: 75 | Sxx = np.abs(Sxx) ** 2 76 | p = [] # power vector - store PSD at each freq 77 | for fsv in target_freqs: 78 | f_idx = np.searchsorted(f, fsv, side="left") 79 | if f_idx != fsv and verbose: 80 | print( 81 | f"Warning: couldn't find exact match for SSVEP freq {fsv}. Using {f[f_idx]}" 82 | ) 83 | p.append(Sxx[f_idx]) 84 | f_decoded = target_freqs[np.argmax(p)] 85 | return f_decoded, p 86 | 87 | 88 | from scipy.fft import rfft, rfftfreq # real fft 89 | 90 | 91 | def plot_periodogram(x, fs, ssvep_f0=None, N=2048, figsize=(14, 10), axes=None): 92 | 93 | N = min(N, len(x) - 1) 94 | 95 | if not isinstance(x, np.ndarray): 96 | x = x.values 97 | X = rfft(x, n=N) 98 | Pxx_fft = np.abs(X) ** 2 99 | w1 = np.linspace(0, 1, N // 2 + 1) # norm freq (pi rad/sample) 100 | f1 = w1 * fs / 2 # pi rad/sample corresponds to fs/2 101 | 102 | welch_wins = [N // 2] 103 | Pxx_welch_mat = np.zeros((len(welch_wins), N // 2 + 1)) 104 | 105 | for i, win in enumerate(welch_wins): 106 | f_welch, Pxx_welch = signal.welch( 107 | x, fs, nperseg=win, nfft=N 108 | ) # nperseg = welch window len 109 | Pxx_welch_mat[i, :] = dB(Pxx_welch) 110 | 111 | if axes is None: 112 | fig, (ax0, ax1) = plt.subplots(2, 1, figsize=figsize) 113 | else: 114 | ax0, ax1 = axes 115 | 116 | ax0.plot(f1, dB(Pxx_fft)) 117 | ax0.set_title("Estimated PSD: Standard Periodogram") 118 | 119 | ax1.plot(f_welch, Pxx_welch_mat.T) 120 | ax1.set_title("Estiamted PSD: Welch Averaged Periodogram") 121 | 122 | for ax in (ax0, ax1): 123 | ax.set_xlabel("frequency (Hz)") 124 | ax.set_ylabel("PSD (dB)") 125 | 126 | ax.spines["right"].set_visible(False) 127 | ax.spines["top"].set_visible(False) 128 | 129 | x_max = 60 130 | ax.set_xlim(0, x_max) 131 | ax.set_xticks(np.arange(0, x_max, step=2)) 132 | if ssvep_f0 is not None: 133 | ax.axvline( 134 | ssvep_f0, ls=":", lw=1.5, color="r", label="expected $f^{(0)}_{SSVEP}$" 135 | ) 136 | print(f"Fundamental SSVEP frequency expected at {ssvep_f0}Hz") 137 | ax.legend() 138 | ax.grid() 139 | 140 | plt.tight_layout(pad=1) 141 | 142 | 143 | def stft(x, Nwin, Nfft=None): 144 | import numpy.fft as fft 145 | 146 | """ 147 | Short-time Fourier transform: convert a 1D vector to a 2D array 148 | The short-time Fourier transform (STFT) breaks a long vector into disjoint 149 | chunks (no overlap) and runs an FFT (Fast Fourier Transform) on each chunk. 150 | The resulting 2D array can 151 | Parameters 152 | ---------- 153 | x : array_like 154 | Input signal (expected to be real) 155 | Nwin : int 156 | Length of each window (chunk of the signal). Should be ≪ `len(x)`. 157 | Nfft : int, optional 158 | Zero-pad each chunk to this length before FFT. Should be ≥ `Nwin`, 159 | (usually with small prime factors, for fastest FFT). Default: `Nwin`. 160 | Returns 161 | ------- 162 | out : complex ndarray 163 | `len(x) // Nwin` by `Nfft` complex array representing the STFT of `x`. 164 | 165 | See also 166 | -------- 167 | istft : inverse function (convert a STFT array back to a data vector) 168 | stftbins : time and frequency bins corresponding to `out` 169 | """ 170 | Nfft = Nfft or Nwin 171 | Nwindows = x.size // Nwin 172 | # reshape into array `Nwin` wide, and as tall as possible. This is 173 | # optimized for C-order (row-major) layouts. 174 | arr = np.reshape(x[: Nwindows * Nwin], (-1, Nwin)) 175 | stft = fft.rfft(arr, Nfft) 176 | return stft 177 | 178 | 179 | def stftbins(x, Nwin, Nfft=None, d=1.0): 180 | import numpy.fft as fft 181 | 182 | """ 183 | Time and frequency bins corresponding to short-time Fourier transform. 184 | Call this with the same arguments as `stft`, plus one extra argument: `d` 185 | sample spacing, to get the time and frequency axes that the output of 186 | `stft` correspond to. 187 | Parameters 188 | ---------- 189 | x : array_like 190 | same as `stft` 191 | Nwin : int 192 | same as `stft` 193 | Nfft : int, optional 194 | same as `stft` 195 | d : float, optional 196 | Sample spacing of `x` (or 1 / sample frequency), units of seconds. 197 | Default: 1.0. 198 | Returns 199 | ------- 200 | t : ndarray 201 | Array of length `len(x) // Nwin`, in units of seconds, corresponding to 202 | the first dimension (height) of the output of `stft`. 203 | f : ndarray 204 | Array of length `Nfft`, in units of Hertz, corresponding to the second 205 | dimension (width) of the output of `stft`. 206 | """ 207 | Nfft = Nfft or Nwin 208 | Nwindows = x.size // Nwin 209 | t = np.arange(Nwindows) * (Nwin * d) 210 | f = fft.rfftfreq(Nfft, d) 211 | return t, f 212 | -------------------------------------------------------------------------------- /eeg_lib/logging_server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, jsonify 2 | import json 3 | 4 | import time 5 | 6 | app = Flask(__name__) 7 | 8 | DEFAULT_FILENAME = "./logs/log_data.json" 9 | 10 | 11 | def write_json(filename, data): 12 | with open(filename, "w", encoding="utf-8") as f: 13 | json.dump(data, f, indent=4) 14 | 15 | 16 | def read_json(filename): 17 | with open(filename) as f: 18 | return json.load(f) 19 | 20 | 21 | def log_data(payload, filename=None): 22 | 23 | filename = filename or DEFAULT_FILENAME 24 | session_id = payload.get("session_id", f"default_session_{int(time.time())}") 25 | try: 26 | existing_data = read_json(filename) 27 | except FileNotFoundError: 28 | existing_data = {} 29 | 30 | if session_id in existing_data: 31 | existing_data[session_id].append(payload) 32 | del payload["session_id"] 33 | else: 34 | existing_data[session_id] = [payload] 35 | write_json(filename, existing_data) 36 | print(f"Log file {filename} updated successfully.") 37 | 38 | 39 | @app.route("/", methods=["POST"]) 40 | def save_data(): 41 | data = request.get_json(force=True) 42 | if isinstance(data, str): 43 | data = json.loads(data) 44 | if data is not None: 45 | log_data(data) 46 | return jsonify(msg="data stored successfully"), 200 47 | return jsonify(msg="invalid data payload"), 400 48 | 49 | 50 | if __name__ == "__main__": 51 | app.run(debug=True, host="0.0.0.0", port=5001) 52 | -------------------------------------------------------------------------------- /eeg_lib/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def get_x_offsets(n, width=0.2): 6 | "get x offsets of bar centres for grouped bar charts in matplotlib" 7 | if n % 2 == 0: # even 8 | sides = [width * x / 2 for x in range(1, n + 1, 2)] 9 | return [-1 * x for x in sides[::-1]] + sides 10 | else: # odd 11 | sides = [width * x / 2 for x in range(3, n + 1, 2)] 12 | return [-1 * x for x in sides[::-1]] + [0] + sides 13 | 14 | 15 | def grouped_bar(x, Y, width=0.2, xlabel="", ylabel="", legend=None, ax=None): 16 | """ 17 | Expects Y to be num_samples x num_vars. 18 | 19 | x should be a vector repr. the independent variable 20 | """ 21 | fig = None 22 | if ax is None: 23 | fig, ax = plt.subplots(1, figsize=(16, 6)) 24 | 25 | x_offsets = get_x_offsets(Y.shape[1], width=width * 0.75) 26 | 27 | for i in range(len(x_offsets)): 28 | ax.bar(np.arange(len(x)) + x_offsets[i], Y[:, i], width=width) 29 | 30 | ax.set_xticks(np.arange(len(x))) 31 | ax.set_xticklabels(x) 32 | 33 | ax.spines["right"].set_visible(False) 34 | ax.spines["top"].set_visible(False) 35 | 36 | ax.set_xlabel(xlabel) 37 | ax.set_ylabel(ylabel) 38 | 39 | ax.set_axisbelow(True) 40 | ax.yaxis.grid(color="gray", linestyle="dashed", alpha=0.6) 41 | 42 | if legend is not None: 43 | ax.legend(legend) 44 | 45 | return fig, ax 46 | -------------------------------------------------------------------------------- /eeg_lib/styles.mplstyle: -------------------------------------------------------------------------------- 1 | axes.titlesize : 18 2 | axes.labelsize : 16 3 | lines.linewidth : 1.5 4 | lines.markersize : 10 5 | xtick.labelsize : 12 6 | ytick.labelsize : 12 -------------------------------------------------------------------------------- /eeg_lib/synthetic.py: -------------------------------------------------------------------------------- 1 | """ 2 | A set of functions to generate synthetic signals that mimic EEG signals. In particular, SSVEP sinsusoidal signals embedded in noise. 3 | """ 4 | from .utils import standardise_ssvep_tensor 5 | import numpy as np 6 | 7 | 8 | def synth_x(f, Ns, noise_power=0.5, fs=250): 9 | """ 10 | generate a synthetic signal vector 11 | 12 | args: 13 | Ns [int]: number of samples (time samples) 14 | noise_power [float]: variance of WGN noise distribution 15 | """ 16 | t = np.arange(0, Ns / fs, step=1 / fs) 17 | return np.sin(2 * np.pi * f * t) + np.random.normal( 18 | size=Ns, loc=0, scale=noise_power 19 | ) 20 | 21 | 22 | def synth_X(f, Nc, Ns, noise_power=0.5, fs=200, f_std=0.02, noise_std=0.2): 23 | """ 24 | Generate a matrix of several variations of the same target signal. This is used 25 | to simulate the measurement of a common signal over multiple EEG channels 26 | that have different SNR characteristics. 27 | 28 | args: 29 | f [float]: target frequency of synthetic signal (Hz) 30 | Nc [int]: number of channels 31 | Ns [int]: number of samples (time samples) 32 | noise_power [float]: variance of WGN noise distribution 33 | fs [float]: sampling frequency (Hz) 34 | f_std [float]: standard dev. of freq. in generated signal across channels to simulate interference from other frequency components over different channels 35 | noise_std [float]: standard dev. of noise across channels 36 | """ 37 | X = [] 38 | for i in range(Nc): # simulate noisy sinusoids with varying SNR across Nc channels 39 | f_i = f * (1 + np.random.normal(scale=f_std)) 40 | sigma_i = noise_power * (1 + np.random.normal(scale=noise_std)) 41 | x = synth_x(f_i, Ns, noise_power=sigma_i) 42 | 43 | x += 0.2 * synth_x( 44 | f_i * 1.05, Ns 45 | ) # add extraneous neighbouring signals (task unrelated) 46 | x += 0.1 * synth_x(f_i * 1.1, Ns) 47 | X.append(x) 48 | 49 | return np.array(X) 50 | 51 | 52 | def synth_data_tensor(stim_freqs, Ns, Nc, Nt, noise_power, fs, Nh=3): 53 | """ 54 | Generate a synthetic 4th order tensor (Chi) of dim. Nf x Nc x Ns x Nt 55 | 56 | args: 57 | stim_freqs [float]: stimulus frequencies of interest (SSVEP). `Nf = len(stim_freqs)` 58 | Nc [int]: number of channels 59 | Ns [int]: number of samples (time samples) 60 | Nt [int]: number of trials 61 | Nh [int]: number of harmonics in sinusoidal ref. signal 62 | noise_power [float]: variance of WGN noise distribution 63 | fs [float]: sampling frequency (Hz) 64 | """ 65 | out_tensor = [] 66 | for f in stim_freqs: 67 | out_tensor.append( 68 | np.array( 69 | [synth_X(f, Nc, Ns, noise_power=noise_power) for i in range(Nt)] 70 | ).transpose(1, 2, 0) 71 | ) 72 | X = np.array(out_tensor) 73 | return standardise_ssvep_tensor(X) 74 | -------------------------------------------------------------------------------- /eeg_lib/trca.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import pearsonr 3 | 4 | # Task-related component analysis: Tanaka et al 5 | class TRCA: 6 | def __init__(self): 7 | self.W = None 8 | self.lam = None 9 | self.Y = None 10 | self.X = None 11 | 12 | def fit(self, X): 13 | """ 14 | :param 15 | X: data tensor (Nc x Ns x Nt) 16 | """ 17 | Nc, Ns, Nt = X.shape 18 | 19 | S = np.zeros((Nc, Nc)) # inter-trial (inter-block) covariance matrix 20 | 21 | # computation of correlation matrices: 22 | for i in range(Nc): 23 | for j in range(Nc): 24 | for k in range(Nt): 25 | for l in range(Nt): 26 | if k != l: # compare blocks (trials) l and k 27 | xi = X[i, :, k].reshape(1, -1) 28 | xj = X[j, :, l].reshape(1, -1) 29 | S[i, j] += np.dot( 30 | (xi - np.mean(xi, axis=1)), (xj - np.mean(xj, axis=1)).T 31 | ) 32 | 33 | X_bar = X.reshape((Nc, Ns * Nt)) - np.tile( 34 | X.reshape((Nc, Ns * Nt)).mean(axis=1).reshape(Nc, 1), (1, Ns * Nt) 35 | ) 36 | 37 | Q = np.dot(X_bar, X_bar.T) # Nc x Nc data covariance matrix 38 | lam, W = np.linalg.eig(np.dot(np.linalg.inv(Q), S)) 39 | 40 | print(lam, W) 41 | 42 | i = np.argsort(np.real(lam))[::-1] # get order of largest eigenvalues in lam 43 | 44 | self.X = X 45 | self.W = W[:, i] 46 | self.lam = lam[i] 47 | self.Y = np.dot(self.W[:, 0].T, X_bar) 48 | 49 | def compute_corr(self, X_test): 50 | X_av = self.X.mean(axis=-1) 51 | w = self.W[:, 0] # get eig. vector corresp to largest eig val 52 | return pearsonr( 53 | np.squeeze(w.T.dot(X_test)), np.squeeze(np.squeeze(w.T.dot(X_av))) 54 | )[0] 55 | 56 | def get_eig(self): 57 | return self.lam, self.W 58 | 59 | 60 | class TRCA_SSVEP: 61 | def __init__(self, stim_freqs): 62 | self.stim_freqs = stim_freqs 63 | self.models = { 64 | f: TRCA() for f in stim_freqs 65 | } # init independent TRCA models per stim freq 66 | 67 | def fit(self, X_ssvep): 68 | """ 69 | Fit the independent Nf TRCA models using input data tensor `X_ssvep` 70 | 71 | :param 72 | X_ssvep: 4th order data tensor (Nf x Nc x Ns x Nt) 73 | """ 74 | assert ( 75 | len(X_ssvep.shape) == 4 76 | ), "Expected a 4th order data tensor with shape (Nf x Nc x Ns x Nt)" 77 | assert ( 78 | len(self.stim_freqs) == X_ssvep.shape[0] 79 | ), "Length of supplied stim freqs does not match first dimension of input data" 80 | 81 | for i, f in enumerate(self.stim_freqs): 82 | self.models[f].fit(X_ssvep[i, :, :, :]) 83 | 84 | def compute_corr(self, X_test): 85 | assert len(X_test.shape) == 2, "Expected a matrix with shape (Nc x Ns)" 86 | 87 | return {f: self.models[f].compute_corr(X_test) for f in self.stim_freqs} 88 | 89 | def get_eig(self): 90 | return {f: self.models[f].get_eig() for f in self.stim_freqs} 91 | -------------------------------------------------------------------------------- /eeg_lib/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import json 4 | 5 | from numpy import linalg as LA 6 | 7 | dB = lambda x: 10 * np.log10(x) # convert mag. to dB 8 | 9 | 10 | def solve_gen_eig_prob(A, B, eps=1e-8): 11 | """ 12 | Solves the generalised eigenvalue problem of the form: 13 | Aw = \lambda*Bw 14 | 15 | Note: can be validated against `scipy.linalg.eig(A, b=B)` 16 | 17 | Ref: 18 | 'Eigenvalue and Generalized Eigenvalue Problems: Tutorial (2019)' 19 | Benyamin Ghojogh and Fakhri Karray and Mark Crowley 20 | arXiv 1903.11240 21 | 22 | """ 23 | Lam_b, Phi_b = LA.eig(B) # eig decomp of B alone 24 | Lam_b = np.eye(len(Lam_b)) * Lam_b # convert to diagonal matrix of eig vals 25 | 26 | Lam_b_sq = np.nan_to_num(Lam_b ** 0.5) + eps * np.eye(len(Lam_b)) 27 | Phi_b_hat = Phi_b.dot(LA.inv(Lam_b_sq)) 28 | A_hat = Phi_b_hat.T.dot(A).dot(Phi_b_hat) 29 | Lam_a, Phi_a = LA.eig(A_hat) 30 | Lam_a = np.eye(len(Lam_a)) * Lam_a 31 | 32 | Lam = Lam_a 33 | Phi = Phi_b_hat.dot(Phi_a) 34 | 35 | return np.diag(Lam), Phi 36 | 37 | 38 | def inv_square(A): 39 | """ 40 | Compute inverse square root of a matrix using Cholesky decomp. 41 | 42 | Requires A to be positive definite. 43 | """ 44 | return LA.inv(LA.cholesky(A)) 45 | 46 | 47 | def save_data_npz(fname, data, **kwargs): 48 | 49 | if not isinstance(data, np.ndarray): 50 | data = data.values 51 | 52 | np.savez(fname, data=data, **kwargs) 53 | 54 | 55 | def load_df(fname, key="data", cols=None): 56 | if cols is None: 57 | cols = [f"chan{i}" for i in range(1, 5)] 58 | df = pd.DataFrame(np.load(fname)[key], columns=cols) 59 | return df 60 | 61 | 62 | def standardise(X): 63 | axis = np.argmax(X.shape) 64 | return (X - np.mean(X, axis=axis)) / np.std(X, axis=axis) 65 | 66 | 67 | def standardise_ssvep_tensor(X): 68 | # Given a obs matrix for given f, and trial, rows (channels) should all be zero mean and unit std dev 69 | 70 | Nf, Nc, Ns, Nt = X.shape 71 | 72 | for n in range(Nf): 73 | for t in range(Nt): 74 | obs = X[n, :, :, t] 75 | mu = np.broadcast_to(obs.mean(axis=1), (Ns, Nc)).T 76 | sigma = np.broadcast_to(obs.std(axis=1), (Ns, Nc)).T 77 | X[n, :, :, t] = (obs - mu) / sigma 78 | 79 | return X 80 | 81 | 82 | def resample(X, factor): 83 | idx_rs = np.arange(0, len(X) - 1, factor) 84 | return X[idx_rs] 85 | 86 | 87 | def load_trials(path_pattern, verbose=False): 88 | all_files = glob.glob(path_pattern) 89 | data = [] 90 | 91 | min_len = ( 92 | 10e6 # trials lengths will be very similar but may differ by 1 or 2 samples 93 | ) 94 | for filename in all_files: 95 | if verbose: 96 | print(f"Loading file {filename}") 97 | f = np.load(filename, allow_pickle=True) 98 | data.append(f["data"]) 99 | if len(f["data"]) < min_len: 100 | min_len = len(f["data"]) 101 | 102 | return np.array([trial[: min_len - 1] for trial in data]) 103 | 104 | 105 | def write_json(filename, data): 106 | with open(filename, "w") as f: 107 | json.dump(data, f) 108 | 109 | 110 | def read_json(filename): 111 | with open(filename) as f: 112 | return json.load(f) 113 | -------------------------------------------------------------------------------- /esp32-cmake.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export BUILD_DIR="${BUILD_DIR:=$(~/)}" 3 | 4 | echo "--- CLONING ULAB ---" 5 | cd $BUILD_DIR 6 | git clone --depth 1 https://github.com/v923z/micropython-ulab.git ulab 7 | 8 | echo "--- CLONING MICROPYTHON ---" 9 | git clone --depth 1 https://github.com/micropython/micropython.git 10 | 11 | echo "--- CLONING ESP-IDF ---" 12 | cd $BUILD_DIR/micropython/ 13 | git clone --depth 1 -b v4.0.2 --recursive https://github.com/espressif/esp-idf.git 14 | 15 | echo "--- INSTALL ESP-IDF ---" 16 | cd $BUILD_DIR/micropython/esp-idf 17 | ./install.sh 18 | . ./export.sh 19 | 20 | echo "--- MPY-CROSS ---" 21 | cd $BUILD_DIR/micropython/mpy-cross 22 | make 23 | 24 | echo "--- ESP32 SUBMODULES ---" 25 | cd $BUILD_DIR/micropython/ports/esp32 26 | make submodules 27 | 28 | echo "--- PATCH MAKEFILE ---" 29 | cp $BUILD_DIR/micropython/ports/esp32/MakeFile $BUILD_DIR/micropython/ports/esp32/MakeFileOld 30 | echo "BOARD = GENERIC" > $BUILD_DIR/micropython/ports/esp32/MakeFile 31 | echo "USER_C_MODULES = \$(BUILD_DIR)/ulab/code/micropython.cmake" >> $BUILD_DIR/micropython/ports/esp32/MakeFile 32 | cat $BUILD_DIR/micropython/ports/esp32/MakeFileOld >> $BUILD_DIR/micropython/ports/esp32/MakeFile 33 | 34 | echo "--- MAKE ---" 35 | make -------------------------------------------------------------------------------- /experimentation/Eigenvalue Algorithms.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "def hessenberg(A):\n", 20 | " # init F\n", 21 | " F = np.zeros(A.shape) \n", 22 | " for idx in range(1, A.shape[0]):\n", 23 | " F[idx, idx - 1] = 1.0\n", 24 | " \n", 25 | " # init Z\n", 26 | " Z = np.zeros(A.shape)\n", 27 | " Z[0, 0] = 1.0\n", 28 | " \n", 29 | " # recursive formula: Fik = (A * Zk)_i - sum_j=1^i-1 {Fjk * Zij} / Z_ii\n", 30 | " for k in range(1, A.shape[1] + 1):\n", 31 | " Azk = np.matmul(A, Z[:, k - 1])\n", 32 | " \n", 33 | " for i in range(0, k):\n", 34 | " temp = 0.0\n", 35 | " for j in range(0, i):\n", 36 | " temp += F[j, k - 1] * Z[i, j]\n", 37 | " F[i, k - 1] = (Azk[i] - temp) / Z[i, i]\n", 38 | " \n", 39 | " if k < A.shape[1]: # to get the last row of F, but here Z[:, k] would be out of range\n", 40 | " Z[:, k] = Azk[:]\n", 41 | " for t in range(0, k):\n", 42 | " Z[:, k] -= F[t, k - 1] * Z[:, t]\n", 43 | " \n", 44 | " return F, Z\n", 45 | "\n", 46 | "def solve_qr(A, iterations=30):\n", 47 | "\n", 48 | " Ak = A\n", 49 | " Q_bar = np.eye(*Ak.shape)\n", 50 | "\n", 51 | " for k in range(iterations):\n", 52 | " Qk, Rk = np.linalg.qr(Ak)\n", 53 | " Ak = np.dot(Rk, Qk)\n", 54 | " Q_bar = Q_bar.dot(Qk)\n", 55 | "\n", 56 | " lam = np.diag(Ak)\n", 57 | " return lam, Q_bar\n", 58 | "\n", 59 | "def solve_gen_eig_prob(A, B, eps=1e-6):\n", 60 | " \"\"\"\n", 61 | " Solves the generalised eigenvalue problem of the form:\n", 62 | " Aw = \\lambda*Bw\n", 63 | " \n", 64 | " Note: can be validated against `scipy.linalg.eig(A, b=B)`\n", 65 | " \n", 66 | " Ref: \n", 67 | " 'Eigenvalue and Generalized Eigenvalue Problems: Tutorial (2019)'\n", 68 | " Benyamin Ghojogh and Fakhri Karray and Mark Crowley\n", 69 | " arXiv 1903.11240\n", 70 | "\n", 71 | " \"\"\"\n", 72 | " Lam_b, Phi_b = np.linalg.eig(B) # eig decomp of B alone\n", 73 | " Lam_b = np.eye(len(Lam_b))*Lam_b # convert to diagonal matrix of eig vals\n", 74 | " \n", 75 | " Lam_b_sq = replace_nan(Lam_b**0.5)+np.eye(len(Lam_b))*eps\n", 76 | " Phi_b_hat = np.dot(Phi_b, np.linalg.inv(Lam_b_sq))\n", 77 | " A_hat = np.dot(np.dot(Phi_b_hat.transpose(), A), Phi_b_hat)\n", 78 | " Lam_a, Phi_a = np.linalg.eig(A_hat)\n", 79 | " Lam_a = np.eye(len(Lam_a))*Lam_a\n", 80 | " \n", 81 | " Lam = Lam_a\n", 82 | " Phi = np.dot(Phi_b_hat, Phi_a)\n", 83 | " \n", 84 | " return np.diag(Lam), Phi\n", 85 | "\n", 86 | "def solve_eig_qr(A, n_eig, lam_iterations=5):\n", 87 | " # !! note: eigenvectors can only be found reliably if A is symmetric\n", 88 | " Ak = A\n", 89 | " n_eig = min(n_eig, min(A.shape))\n", 90 | "\n", 91 | " for k in range(lam_iterations):\n", 92 | " Qk, Rk = np.linalg.qr(Ak)\n", 93 | " Ak = np.dot(Rk, Qk)\n", 94 | "\n", 95 | " lam = np.diag(Ak) # get eigenvalues\n", 96 | " V = []\n", 97 | " for l in lam[:n_eig]: # now find `n_eig` eigenvectors\n", 98 | " A_null = (A - np.eye(A.shape[0])*l).transpose()\n", 99 | " Q, R = np.linalg.qr(A_null) # compute null space of (A-lam*I) to get eigenvector\n", 100 | " V.append(Q[:, -1])\n", 101 | " return lam, np.array(V).transpose()\n", 102 | "\n", 103 | "def power_iteration(A, iterations):\n", 104 | " \"\"\"\n", 105 | " Iterative algo. to find the eigenvector of a matrix A corresponding to the largest\n", 106 | " eigenvalue.\n", 107 | " \n", 108 | " TODO: Establish some measure or heuristic of min number of iterations required\n", 109 | " \"\"\"\n", 110 | " # choose random initial vector to reduce risk of choosing one orthogonal to \n", 111 | " # target eigen vector\n", 112 | " b_k = np.array([urandom.random() for i in range(len(A))])\n", 113 | "\n", 114 | " for _ in range(iterations):\n", 115 | " b_k1 = np.dot(A, b_k)\n", 116 | " b_k1_norm = np.linalg.norm(b_k1)\n", 117 | " # re normalize the vector\n", 118 | " b_k = b_k1 / b_k1_norm\n", 119 | "\n", 120 | " return b_k1_norm, b_k\n", 121 | "\n", 122 | "def max_eig(A, iterations, numeric_method='qr'):\n", 123 | " \"\"\"\n", 124 | " Function to return the largest eigenvalue of a matrix and its corresponding eigenvector.\n", 125 | " \n", 126 | " A must be square but need not be symmetric. Tries to first use uLab `np.linalg.eig`\n", 127 | " that is better optimised but requires a symmetric matrix. Failing this, power iteration \n", 128 | " algorithm is used.\n", 129 | " \"\"\"\n", 130 | " try:\n", 131 | " lam, V = np.linalg.eig(A)\n", 132 | " v = V[:, np.argmax(lam)]\n", 133 | " except ValueError:\n", 134 | " if numeric_method == 'power_iteration':\n", 135 | " lam, v = power_iteration(A, iterations)\n", 136 | " else:\n", 137 | " if numeric_method != 'qr':\n", 138 | " print(\"Unknown `numeric_method` arg: defaulting to QR solver\")\n", 139 | " lam, v = solve_eig_qr(A, 1, lam_iterations=iterations)\n", 140 | " lam = lam[0] # only need first eigen val (largest returned first)\n", 141 | " v = v[:, 0] # only first eig vector \n", 142 | " \n", 143 | " return lam, v\n", 144 | "\n", 145 | "def replace_nan(A, rep=0):\n", 146 | " return np.where(np.isfinite(A), A, rep)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 4, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "class CCA():\n", 156 | " \n", 157 | " def __init__(self, stim_freqs, fs, Nh=2):\n", 158 | " self.Nh = Nh\n", 159 | " self.stim_freqs = stim_freqs\n", 160 | " self.fs = fs\n", 161 | " \n", 162 | " def compute_corr(self, X_test): \n", 163 | " result = {}\n", 164 | " Cxx = np.dot(X_test, X_test.transpose()) # precompute data auto correlation matrix\n", 165 | " for f in self.stim_freqs:\n", 166 | " Y = harmonic_reference(f, self.fs, np.max(X_test.shape), Nh=self.Nh, standardise_out=True)\n", 167 | " rho = self.cca_eig(X_test, Y, Cxx=Cxx) # canonical variable matrices. Xc = X^T.W_x\n", 168 | " result[f] = rho\n", 169 | " return result\n", 170 | " \n", 171 | " @staticmethod\n", 172 | " def cca_eig(X, Y, Cxx=None):\n", 173 | " if Cxx is None:\n", 174 | " Cxx = np.dot(X, X.transpose()) # auto correlation matrix\n", 175 | " Cyy = np.dot(Y, Y.transpose()) \n", 176 | " Cxy = np.dot(X, Y.transpose()) # cross correlation matrix\n", 177 | " Cyx = np.dot(Y, X.transpose()) # same as Cxy.T\n", 178 | "\n", 179 | " M1 = np.dot(np.linalg.inv(Cxx), Cxy) # intermediate result\n", 180 | " M2 = np.dot(np.linalg.inv(Cyy), Cyx)\n", 181 | "\n", 182 | " lam, _ = max_eig(np.dot(M1, M2), 20)\n", 183 | " return np.sqrt(lam)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "### CCA Example Problem\n", 191 | "Testing finding canonical correlations using eigenvalues of covariance matrices. `X` is a data matrix and `Y` is a matrix of reference signals with 2 harmonics" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 5, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "(1, 150) (4, 150)\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "from eeg_lib.synthetic import synth_X\n", 209 | "from eeg_lib.cca import cca_reference\n", 210 | "\n", 211 | "Ns = 150\n", 212 | "\n", 213 | "X = synth_X(7, 1, Ns, noise_power=0.2, f_std=0.04)\n", 214 | "Y = cca_reference([15], 250, Ns, Nh=2)\n", 215 | "\n", 216 | "# X = X.T\n", 217 | "# Y = Y.T\n", 218 | "\n", 219 | "print(X.shape, Y.shape)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 6, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "data": { 229 | "text/plain": [ 230 | "array([0.07173047])" 231 | ] 232 | }, 233 | "execution_count": 6, 234 | "metadata": {}, 235 | "output_type": "execute_result" 236 | } 237 | ], 238 | "source": [ 239 | "from sklearn.cross_decomposition import CCA as CCA_sklearn\n", 240 | "\n", 241 | "n = min(Y.T.shape[1], X.T.shape[1])\n", 242 | "cca = CCA_sklearn(n_components=n)\n", 243 | "cca.fit(X.T, Y.T)\n", 244 | "\n", 245 | "X_c, Y_c = cca.transform(X.T, Y.T)\n", 246 | "result = np.corrcoef(X_c.T, Y_c.T).diagonal(offset=n)\n", 247 | "result" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 7, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "data": { 257 | "text/plain": [ 258 | "array([0.07133839])" 259 | ] 260 | }, 261 | "execution_count": 7, 262 | "metadata": {}, 263 | "output_type": "execute_result" 264 | } 265 | ], 266 | "source": [ 267 | "Cxx = np.dot(X, X.transpose()) # auto correlation matrix\n", 268 | "Cyy = np.dot(Y, Y.transpose()) \n", 269 | "Cxy = np.dot(X, Y.transpose()) # cross correlation matrix\n", 270 | "Cyx = np.dot(Y, X.transpose()) # same as Cxy.T\n", 271 | "\n", 272 | "M1 = np.dot(np.linalg.inv(Cxx), Cxy) # intermediate result\n", 273 | "M2 = np.dot(np.linalg.inv(Cyy), Cyx)\n", 274 | "M = np.dot(M1, M2)\n", 275 | "\n", 276 | "lam, V = np.linalg.eig(M)\n", 277 | "np.sqrt(lam)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 228, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "import ujson as json\n", 287 | "\n", 288 | "data = {'X': X.tolist(), 'Y': Y.tolist()}\n", 289 | "with open('xy.json', 'w') as jsonfile:\n", 290 | " json.dump(data, jsonfile)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 216, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "data": { 300 | "text/plain": [ 301 | "array([0.92299, 0.07742, 0.01606, 0.00484])" 302 | ] 303 | }, 304 | "execution_count": 216, 305 | "metadata": {}, 306 | "output_type": "execute_result" 307 | } 308 | ], 309 | "source": [ 310 | "lam, V = solve_qr(M, iterations=100)\n", 311 | "np.sqrt(lam).round(5)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 195, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "Cxx = X.dot(X.T)\n", 321 | "Cxy = X.dot(Y.T)\n", 322 | "Cyy = Y.dot(Y.T)\n", 323 | "Cyx = Y.dot(X.T)\n", 324 | "\n", 325 | "def block_diag(X, Y, reverse=False):\n", 326 | " if not reverse:\n", 327 | " X = np.concatenate([X, np.zeros_like(X)], axis=1)\n", 328 | " Y = np.concatenate([np.zeros_like(Y), Y], axis=1)\n", 329 | " else:\n", 330 | " X = np.concatenate([np.zeros_like(X), X], axis=1)\n", 331 | " Y = np.concatenate([Y, np.zeros_like(Y)], axis=1)\n", 332 | " return np.concatenate([X, Y], axis=0)\n", 333 | "\n", 334 | "A = block_diag(Cxy, Cyx, reverse=True)\n", 335 | "B = block_diag(Cxx, Cyy)\n", 336 | "\n", 337 | "lam, Phi = solve_gen_eig_prob(A, B, eps=1e-12)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 197, 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "data": { 347 | "text/plain": [ 348 | "array([ 0.96230964, -0.96230964, 0.6550056 , -0.6550056 , 0.03841503,\n", 349 | " -0.03841503, 0.00107915, -0.00107915])" 350 | ] 351 | }, 352 | "execution_count": 197, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "lam" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 189, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "# args = [X, Y]\n", 368 | "# Z = np.zeros((sum(arg.shape[0] for arg in args), sum(arg.shape[1] for arg in args)))\n", 369 | "# origin = [0, 0]\n", 370 | "# for arg in args:\n", 371 | "# x0 = origin[0]\n", 372 | "# y0 = origin[1]\n", 373 | "# Z[x0:x0+arg.shape[0], y0:y0+arg.shape[1]] = arg\n", 374 | "# origin[0] += arg.shape[0]\n", 375 | "# origin[1] += arg.shape[1]\n", 376 | "\n", 377 | "# args = [X, Y]\n", 378 | "# Z = np.zeros((sum(arg.shape[0] for arg in args), sum(arg.shape[1] for arg in args)))\n", 379 | "# origin = Z.shape\n", 380 | "# for arg in args:\n", 381 | "# x0 = origin[0]\n", 382 | "# y0 = origin[1]\n", 383 | "# Z[x0:x0-arg.shape[0], y0:y0-arg.shape[1]] = arg\n", 384 | "# origin[0] -= arg.shape[0]\n", 385 | "# origin[1] -= arg.shape[1]" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 79, 391 | "metadata": {}, 392 | "outputs": [ 393 | { 394 | "data": { 395 | "text/plain": [ 396 | "array([[ 2, 95, -38, 18, 5],\n", 397 | " [ 1, 47, -19, 8, 1],\n", 398 | " [ 2, 151, -69, 28, 4],\n", 399 | " [ -1, 218, -88, 34, 6],\n", 400 | " [ 0, -208, 84, -34, -5]])" 401 | ] 402 | }, 403 | "execution_count": 79, 404 | "metadata": {}, 405 | "output_type": "execute_result" 406 | } 407 | ], 408 | "source": [ 409 | "A = np.array([2, 95, -38, 18, 5, 1, 47, -19, 8, 1, 2, 151, -69, 28, 4, -1, 218, -88, 34, 6, 0, -208, 84, -34, -5])\n", 410 | "A = A.reshape((5,5))\n", 411 | "A" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 80, 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "Eigenvalues: \n", 424 | "[ 25.57771123+0.j -12.31982385+0.j -3.21680685+1.4877582j\n", 425 | " -3.21680685-1.4877582j 2.17572632+0.j ], \n", 426 | "\n", 427 | " Eigenvectors (rounded): \n", 428 | "[[ 0.3 +0.j -0.17+0.j 0.08+0.16j 0.08-0.16j 0.57+0.j ]\n", 429 | " [ 0.16+0.j -0.08+0.j -0.01+0.01j -0.01-0.01j 0. +0.j ]\n", 430 | " [ 0.41+0.j -0.51+0.j -0.34+0.02j -0.34-0.02j -0.11+0.j ]\n", 431 | " [ 0.58+0.j -0.67+0.j -0.83+0.j -0.83-0.j -0.43+0.j ]\n", 432 | " [-0.62+0.j 0.51+0.j 0.38-0.11j 0.38+0.11j 0.69+0.j ]]\n" 433 | ] 434 | } 435 | ], 436 | "source": [ 437 | "lam_ref, V_ref = np.linalg.eig(A)\n", 438 | "print(f\"Eigenvalues: \\n{lam_ref}, \\n\\n Eigenvectors (rounded): \\n{V_ref.round(2)}\")" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 81, 444 | "metadata": {}, 445 | "outputs": [ 446 | { 447 | "data": { 448 | "text/plain": [ 449 | "array([ 25.57771123, -12.31982385, -1.81028516, -4.62332854,\n", 450 | " 2.17572632])" 451 | ] 452 | }, 453 | "execution_count": 81, 454 | "metadata": {}, 455 | "output_type": "execute_result" 456 | } 457 | ], 458 | "source": [ 459 | "lam, V = solve_qr(A, iterations=100)\n", 460 | "lam" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 82, 466 | "metadata": {}, 467 | "outputs": [ 468 | { 469 | "data": { 470 | "text/plain": [ 471 | "array([[-0.3 , -0.54, -0.49, -0.42, -0.45],\n", 472 | " [-0.16, -0.34, 0.27, 0.74, -0.48],\n", 473 | " [-0.41, 0.49, -0.66, 0.39, 0.03],\n", 474 | " [-0.58, 0.44, 0.48, -0.33, -0.35],\n", 475 | " [ 0.62, 0.39, -0.15, -0.06, -0.66]])" 476 | ] 477 | }, 478 | "execution_count": 82, 479 | "metadata": {}, 480 | "output_type": "execute_result" 481 | } 482 | ], 483 | "source": [ 484 | "V.round(2)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 83, 490 | "metadata": {}, 491 | "outputs": [ 492 | { 493 | "data": { 494 | "text/plain": [ 495 | "array([ 25.57771123, -12.31982385, -3.09449956, -3.33911414,\n", 496 | " 2.17572632])" 497 | ] 498 | }, 499 | "execution_count": 83, 500 | "metadata": {}, 501 | "output_type": "execute_result" 502 | } 503 | ], 504 | "source": [ 505 | "A_f, _ = hessenberg(A)\n", 506 | "lam_f, V_f = solve_qr(A_f, iterations=100)\n", 507 | "lam_f" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": 84, 513 | "metadata": {}, 514 | "outputs": [ 515 | { 516 | "data": { 517 | "text/plain": [ 518 | "array([[ 0.882, -0.205, 0.424, 0.001, -0. ],\n", 519 | " [ 0.471, 0.413, -0.779, -0.003, -0. ],\n", 520 | " [-0.015, 0.887, 0.461, 0.013, 0. ],\n", 521 | " [ 0.001, -0.01 , -0.008, 0.981, 0.194],\n", 522 | " [ 0. , 0.002, 0.002, -0.194, 0.981]])" 523 | ] 524 | }, 525 | "execution_count": 84, 526 | "metadata": {}, 527 | "output_type": "execute_result" 528 | } 529 | ], 530 | "source": [ 531 | "V_f.round(3)" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 14, 537 | "metadata": {}, 538 | "outputs": [ 539 | { 540 | "data": { 541 | "text/plain": [ 542 | "array([[ 2, 1, 778, 35107, -15046],\n", 543 | " [ 1, 1, 389, 17954, -3009],\n", 544 | " [ 0, 1, -36, -1527, 354],\n", 545 | " [ 0, 0, 1, 42, -9],\n", 546 | " [ 0, 0, 0, 1, 0]])" 547 | ] 548 | }, 549 | "execution_count": 14, 550 | "metadata": {}, 551 | "output_type": "execute_result" 552 | } 553 | ], 554 | "source": [ 555 | "F, Z = hessenberg(A)\n", 556 | "F.astype(int)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 85, 562 | "metadata": {}, 563 | "outputs": [ 564 | { 565 | "data": { 566 | "text/plain": [ 567 | "(array([0.02148729, 0.14447535, 1.00472133, 0.99984663]),\n", 568 | " array([[ 0.02902728, 0.02892596, 0.02260202, 0.02290183],\n", 569 | " [ 0.7939542 , 0.79404442, 0.79942807, 0.79912809],\n", 570 | " [-0.53426395, -0.53399658, -0.51715765, -0.51825175],\n", 571 | " [-0.28871471, -0.28897127, -0.3048801 , -0.30378435]]))" 572 | ] 573 | }, 574 | "execution_count": 85, 575 | "metadata": {}, 576 | "output_type": "execute_result" 577 | } 578 | ], 579 | "source": [ 580 | "X = np.array([[0.0, 0.2552531, 0.4935954, 0.6992362, 0.8585516, 0.9609866, 0.9997549, 0.972288, 0.8804055, 0.7301948],\n", 581 | " [0.0, 0.2651061, 0.5112409, 0.7207904, 0.8787591, 0.9738424, 0.9992362, 0.9531231, 0.838803, 0.6644569],\n", 582 | " [0.0, 0.2634635, 0.5083104, 0.7172394, 0.8754874, 0.9718722, 0.9995833, 0.9566626, 0.8461428, 0.6758333],\n", 583 | " [0.0, 0.2671577, 0.5148946, 0.7252015, 0.8827904, 0.9762053, 0.9986557, 0.9485094, 0.8294118, 0.6500207]])\n", 584 | "\n", 585 | "# Y = cca_reference([7], 200, 10, Nh=2)\n", 586 | "\n", 587 | "Y = np.array([[-2.171207, -1.338523, -0.5880827, 0.04396701, 0.5271821, 0.8382883, 0.9623005, 0.8932458, 0.6344502, 0.1983788],\n", 588 | " [1.305342, 1.170641, 0.9533603, 0.6639652, 0.3163952, -0.07260892, -0.4843098, -0.8988775, -1.296343, -1.657563],\n", 589 | " [0.2679634, 0.7796801, 1.073691, 1.094033, 0.8368343, 0.3510505, -0.2708505, -0.9104931, -1.446123, -1.775786],\n", 590 | " [1.850696, 1.432374, 0.8242437, 0.1420597, -0.4843266, -0.9356859, -1.126103, -1.019333, -0.6356997, -0.04822552]])\n", 591 | "\n", 592 | "Cxx = X.dot(X.T)\n", 593 | "Cxy = X.dot(Y.T)\n", 594 | "Cyy = Y.dot(Y.T)\n", 595 | "Cyx = Y.dot(X.T)\n", 596 | "\n", 597 | "M1 = np.linalg.inv(Cxx).dot(Cxy)\n", 598 | "M2 = np.linalg.inv(Cyy).dot(Cyx)\n", 599 | "\n", 600 | "M = M1.dot(M2)\n", 601 | "\n", 602 | "np.linalg.eig(M)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 86, 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [ 611 | "M_f, _ = hessenberg(M)" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 87, 617 | "metadata": {}, 618 | "outputs": [ 619 | { 620 | "data": { 621 | "text/plain": [ 622 | "array([1.0048447 , 0.99972344, 0.14443503, 0.02152743])" 623 | ] 624 | }, 625 | "execution_count": 87, 626 | "metadata": {}, 627 | "output_type": "execute_result" 628 | } 629 | ], 630 | "source": [ 631 | "lam, V = solve_qr(M_f, iterations=100)\n", 632 | "lam" 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": 88, 638 | "metadata": {}, 639 | "outputs": [ 640 | { 641 | "data": { 642 | "text/plain": [ 643 | "array([[ 1. , 0.031, 0. , -0. ],\n", 644 | " [-0. , -0. , 0.031, -1. ],\n", 645 | " [ 0. , 0. , -1. , -0.031],\n", 646 | " [ 0.031, -1. , -0. , -0. ]])" 647 | ] 648 | }, 649 | "execution_count": 88, 650 | "metadata": {}, 651 | "output_type": "execute_result" 652 | } 653 | ], 654 | "source": [ 655 | "V.round(3)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 89, 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "data": { 665 | "text/plain": [ 666 | "array([[-0.02259485, -0.0228978 , -0.028926 , -0.02902725],\n", 667 | " [-0.79943385, -0.79915401, -0.79404439, -0.79395422],\n", 668 | " [ 0.51713865, 0.51812389, 0.53399668, 0.53426387],\n", 669 | " [ 0.30489771, 0.30393452, 0.28897118, 0.28871478]])" 670 | ] 671 | }, 672 | "execution_count": 89, 673 | "metadata": {}, 674 | "output_type": "execute_result" 675 | } 676 | ], 677 | "source": [ 678 | "V2 = []\n", 679 | "for l in lam: # now find `n_eig` eigenvectors\n", 680 | " A_null = (M - np.eye(M.shape[0])*l).transpose()\n", 681 | " Q, R = np.linalg.qr(A_null) # compute null space of (A-lam*I) to get eigenvector\n", 682 | " V2.append(Q[:, -1])\n", 683 | " \n", 684 | "V2 = np.array(V2).T\n", 685 | "V2" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": 70, 691 | "metadata": {}, 692 | "outputs": [ 693 | { 694 | "data": { 695 | "text/plain": [ 696 | "array([-0.02259485, -0.79943385, 0.51713865, 0.30489771])" 697 | ] 698 | }, 699 | "execution_count": 70, 700 | "metadata": {}, 701 | "output_type": "execute_result" 702 | } 703 | ], 704 | "source": [ 705 | "V2[:, np.argmax(lam)]" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": null, 711 | "metadata": {}, 712 | "outputs": [], 713 | "source": [] 714 | } 715 | ], 716 | "metadata": { 717 | "kernelspec": { 718 | "display_name": "eeg_env", 719 | "language": "python", 720 | "name": "eeg_env" 721 | }, 722 | "language_info": { 723 | "codemirror_mode": { 724 | "name": "ipython", 725 | "version": 3 726 | }, 727 | "file_extension": ".py", 728 | "mimetype": "text/x-python", 729 | "name": "python", 730 | "nbconvert_exporter": "python", 731 | "pygments_lexer": "ipython3", 732 | "version": "3.8.7" 733 | } 734 | }, 735 | "nbformat": 4, 736 | "nbformat_minor": 4 737 | } 738 | -------------------------------------------------------------------------------- /micropython/README.md: -------------------------------------------------------------------------------- 1 | ## MicroPython setup 2 | ### Overview 3 | The instructions below detail the steps necessary to get a working version of the [ESP32 port of MicroPython](https://docs.micropython.org/en/latest/esp32/quickref.html) on your local machine. This incldues the [`ulab`](https://github.com/v923z/micropython-ulab) extension for Micropython which is required for this project. It is offers very convenient and efficient `numpy`-like array manipulation. 4 | 5 | For more generic information related to setting up MicroPython for other ports, check out the [repo](https://github.com/micropython/micropython). 6 | 7 | ### Setup 8 | Install script tries to install `virutalenv` with `--user` flag which causes issues if you are installing from some virtual environments. Install it first manually to prevent this issue with `pip install virtualenv` (provided you are within an activated virtual environment if applicable). 9 | 10 | Pretty much all setup is contained within the script called `esp32-cmake.sh`. You may need to make it executable first using `chmod u+x esp32-cmake.sh`. Then, setup the build directory: 11 | 12 | ```bash 13 | mkdir ~/mpy-esp32 14 | export BUILD_DIR=~/mpy-esp32 15 | ``` 16 | You're free to perform the setup wherever. Just update `BUILD_DIR` accordingly. Now, we run the setup scrip using: 17 | 18 | ```bash 19 | ./esp32-cmake.sh 20 | ``` 21 | > Note that if you're using macOS with Apple Silicon, you may need to run the above prefixed with arch -x86_64 to ensure compatibility with later steps. 22 | 23 | Once complete, the following may be run for convenience: 24 | ```bash 25 | # export mpy cross compiler to path to use `mpy-cross` command 26 | export PATH=$PATH:~/mpy-esp32/micropython/mpy-cross/ 27 | ``` 28 | This will allow you to cross compile ordinary MicroPython scrips (`*.py` files) into binary `*.mpy` versions which are more efficient in terms of memory and speed. These can be built into the firmware by copying your MicroPython modules into `$BUILD_DIR/micropython/ports/esp32/modules`. 29 | 30 | #### Deployment 31 | Set your serial port. It will look something like this (but probably not the same): `export PORT=/dev/tty.usbserial-02U1W54L`. At least for Unix-based systems, you can list available serial ports using `ls /dev/tty.*`. 32 | 33 | Finally, to deploy to your board, run `make erase && make deploy`. 34 | 35 | ### Aditional notes 36 | #### Precision considerations 37 | An important consideration is the precision needed for your application. If your application requires numerous matrix multiplications or other operations where precision errors are at risk of propagating to an unacceptable degree, you may want to enable __double precision__ in the firmware build. This can be done by updating `MICROPY_FLOAT_IMPL` in `$BUILD_DIR/micropython/ports/esp32/mpconfigport.h` as follows: 38 | ```c 39 | // #define MICROPY_FLOAT_IMPL (MICROPY_FLOAT_IMPL_FLOAT) 40 | #define MICROPY_FLOAT_IMPL (MICROPY_FLOAT_IMPL_DOUBLE) 41 | ``` 42 | __NB__: In order for this to take effect, you'll have to rebuild the firmware image as detailed below. 43 | #### Rebuilding firmware 44 | A new firmware binary can be compiled by running the following within `$BUILD_DIR/micropython/ports/esp32/`: 45 | ```bash 46 | make clean && make all 47 | ``` 48 | Then, follow the deployment steps as mentioned above to flash the new image onto your target board. Note that this requires the ESP-IDF to be exported in your shell environment. If you get an error concerning this, you'll likely need to __export the IDF variables again__. See the section on the ESP-IDF in the troubleshooting section to rectify this. 49 | ### Uploading code 50 | 51 | Install the [`ampy`](https://learn.adafruit.com/micropython-basics-load-files-and-run-code/install-ampy) Python package in your same virtual environment. This allows you to upload, read, manipulate and delete files in non-volatile storage (flash) on the ESP32 over serial. 52 | 53 | Test the installation with `ampy -p /dev/tty.usbserial-02U1W54L ls` to list the files in NVS on the board (remember to replace your port accordingly). You should see something like `/boot.py`. You can see a list of other useful commands by typing `ampy --help`. 54 | 55 | As the MicroPython modules in this project *are not built into the firmware image by default*, you can use `ampy` to send them to your target board. Run 56 | ```bash 57 | ampy -p /dev/tty.usbserial-02U1W54L put lib/ 58 | ``` 59 | from within this directory to copy the `lib` package to the target (remember to __update your port__ accordingly). This will take a good few seconds to complete. Once done, you can access any of the `lib` modules as you would standard Python modules. For example, you can test an import from the `decoding` module as follows: 60 | ```python 61 | from lib.decoding import CCA 62 | ``` 63 | Note: you can also cross compile all modules in `lib` using `mpy-cross` or the `cross-compile.sh` script provided. If you send the compiled `.mpy` versions to the target board, you can still import and use them in exactly the same way (there is no difference from a usage point of view). 64 | ### Development 65 | #### Initial setup 66 | Some of the core functionality in the provided `lib` requires a few key variables to be set in a `.env` file that gets loaded in. You'll need to create a `.env` file in the `lib/` directory and provide the following: 67 | ```bash 68 | WIFI_SSID=your-wifi-ssid 69 | WIFI_PASSWORD=your-wifi-password 70 | 71 | # optional: MQTT server information 72 | MQTT_SERVER=mqtt-server-address 73 | MQTT_PORT=000 74 | MQTT_DEFUALT_TOPIC=mqtt-default-topic 75 | ``` 76 | You'll then need to actually upload the `.env` file to the target. You can either run the following from within a Jupyter Notebook (described below), 77 | ```ipython 78 | %sendtofile lib/.env --source lib/.env 79 | ``` 80 | or you can use `ampy`: 81 | ```bash 82 | ampy -p /dev/tty.usbserial-02U1W54L put .env lib/.env 83 | ``` 84 | Note that the first file argument is the source path and the second is the destination path (where on your target board the file will be uploaded to). 85 | 86 | #### Jupyter over serial 87 | 88 | An extremely useful feature of the MicroPython development platform is that it is compatible with Jupyter notebooks over serial. This has been made possible by projects like [this one](https://github.com/goatchurchprime/jupyter_micropython_kernel/) which contains all the details in getting setup. 89 | 90 | In summary, from within your virtual env, run: 91 | 92 | ```bash 93 | pip install jupyterlab 94 | pip install jupyter_micropython_kernel 95 | python -m jupyter_micropython_kernel.install 96 | ``` 97 | You can run `jupyter kernelspec list` to see where your Jupyter kernels are installed. You should see the `micropython` kernel listed there. To start a notebook, run 98 | ```bash 99 | jupyter lab 100 | ``` 101 | You can also run `jupyter notebook` for a more lightweight version. This will start the Jupyter server and should open a window in your default browser. Then, make sure the selected kernel is the `micropython` kernel you created earlier. In order to connect with the ESP32 over serial, run 102 | ```ipython 103 | %serialconnect to --port=/dev/tty.usbserial-02U1W54L --baud=115200 104 | ``` 105 | replacing your specific serial port as required. After a few moments, should see a response like 106 | ```text 107 | Connecting to --port=/dev/tty.usbserial-02U1W54L --baud=115200 108 | Ready. 109 | ``` 110 | Then, you're free to use your MicroPython board in an interactive notebook environment! You can test it with pretty much any standard python commands, including very convenient structures such as dictionary and list comprehensions. In order to test MicroPython-specific functionality, try running something like 111 | ```python 112 | from machine import Pin 113 | import time 114 | 115 | LED_PIN = 5 # replace as necessary 116 | 117 | led = Pin(LED_PIN, Pin.OUT) # define pin 0 as output 118 | 119 | led.value(1) # set LED on 120 | time.sleep(2) # wait 2 seconds 121 | led.value(0) # set LED off 122 | ``` 123 | Test that `ulab` was built into the compiled firmware image correctly by importing it: 124 | ```python 125 | import ulab 126 | ``` 127 | If no errors are shown, it worked! Here are some example of basic linear algebra functionality offered by `ulab` that give an idea of just how convenient and useful it is 128 | ```python 129 | from ulab import numpy as np 130 | 131 | # create an arbitrary positive definite, symmetric 3x3 matrix 132 | # A can be sliced like A[i0:i1, j0:j1] as with regular numpy 133 | A = np.array([[25, 15, -5], [15, 18, 0], [-5, 0, 11]]) 134 | 135 | A_sqrt = np.linalg.cholesky(A) # compute lower triangular square root or A using Cholesky decomp 136 | 137 | det_A = np.linalg.det(A) # compute determinant of A 138 | 139 | A_inv = np.linalg.inv(A) # compute determinant of A 140 | 141 | # compute Moore-Penrose pseudoinverse of A = (A^T.A)^-1.A^T 142 | A_pinv = np.dot(np.linalg.inv(np.dot(np.transpose(A), A)), np.transpose(A)) 143 | 144 | # many other utility functions such as argmax(), argsort(), convolve() 145 | ``` 146 | **Tip:** as with ordinary Jupyter Notebooks, you can use magic commands. Run `%lsmagic` in any cell to get a list of the magic commands available. These include commands to reset your device, send/retrieve files and others. 147 | 148 | ### Troubleshooting 149 | If you see complaints about not finding `idf.py`, it may be that the ESP-IDF has not been exported properly. To rectify this, navigate to `$BUILD_DIR/micropython/esp-idf`. Make sure the correct python environment has been activated and then run 150 | ```bash 151 | ./install.sh 152 | . ./export.sh 153 | ``` 154 | to install the IDF dependencies and export necessary variables to your shell environment. 155 | 156 | Your `Makefile` under `$BUILD_DIR/micropython/ports/esp32/` should begin with something like this: 157 | ```bash 158 | BOARD = GENERIC 159 | USER_C_MODULES = $(BUILD_DIR)/ulab/code/micropython.cmake 160 | # Makefile for MicroPython on ESP32. 161 | # 162 | # This is a simple, convenience wrapper around idf.py (which uses cmake). 163 | #... 164 | ``` 165 | -------------------------------------------------------------------------------- /micropython/boot.py: -------------------------------------------------------------------------------- 1 | # This file is executed on every boot (including wake-boot from deepsleep) 2 | # import esp 3 | # esp.osdebug(None) 4 | # import webrepl 5 | # webrepl.start() 6 | -------------------------------------------------------------------------------- /micropython/cross-compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | find lib -maxdepth 1 -type f -exec mpy-cross {} \; 3 | mv lib/*.mpy mpy-modules/ 4 | -------------------------------------------------------------------------------- /micropython/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesTev/EEG-decoding/c4054dd9d1eac857aedd487a34f177c97d95c0af/micropython/lib/__init__.py -------------------------------------------------------------------------------- /micropython/lib/aws/aws_ca.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIE0zCCA7ugAwIBAgIQGNrRniZ96LtKIVjNzGs7SjANBgkqhkiG9w0BAQUFADCB 3 | yjELMAkGA1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMR8wHQYDVQQL 4 | ExZWZXJpU2lnbiBUcnVzdCBOZXR3b3JrMTowOAYDVQQLEzEoYykgMjAwNiBWZXJp 5 | U2lnbiwgSW5jLiAtIEZvciBhdXRob3JpemVkIHVzZSBvbmx5MUUwQwYDVQQDEzxW 6 | ZXJpU2lnbiBDbGFzcyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0 7 | aG9yaXR5IC0gRzUwHhcNMDYxMTA4MDAwMDAwWhcNMzYwNzE2MjM1OTU5WjCByjEL 8 | MAkGA1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMR8wHQYDVQQLExZW 9 | ZXJpU2lnbiBUcnVzdCBOZXR3b3JrMTowOAYDVQQLEzEoYykgMjAwNiBWZXJpU2ln 10 | biwgSW5jLiAtIEZvciBhdXRob3JpemVkIHVzZSBvbmx5MUUwQwYDVQQDEzxWZXJp 11 | U2lnbiBDbGFzcyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9y 12 | aXR5IC0gRzUwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCvJAgIKXo1 13 | nmAMqudLO07cfLw8RRy7K+D+KQL5VwijZIUVJ/XxrcgxiV0i6CqqpkKzj/i5Vbex 14 | t0uz/o9+B1fs70PbZmIVYc9gDaTY3vjgw2IIPVQT60nKWVSFJuUrjxuf6/WhkcIz 15 | SdhDY2pSS9KP6HBRTdGJaXvHcPaz3BJ023tdS1bTlr8Vd6Gw9KIl8q8ckmcY5fQG 16 | BO+QueQA5N06tRn/Arr0PO7gi+s3i+z016zy9vA9r911kTMZHRxAy3QkGSGT2RT+ 17 | rCpSx4/VBEnkjWNHiDxpg8v+R70rfk/Fla4OndTRQ8Bnc+MUCH7lP59zuDMKz10/ 18 | NIeWiu5T6CUVAgMBAAGjgbIwga8wDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8BAf8E 19 | BAMCAQYwbQYIKwYBBQUHAQwEYTBfoV2gWzBZMFcwVRYJaW1hZ2UvZ2lmMCEwHzAH 20 | BgUrDgMCGgQUj+XTGoasjY5rw8+AatRIGCx7GS4wJRYjaHR0cDovL2xvZ28udmVy 21 | aXNpZ24uY29tL3ZzbG9nby5naWYwHQYDVR0OBBYEFH/TZafC3ey78DAJ80M5+gKv 22 | MzEzMA0GCSqGSIb3DQEBBQUAA4IBAQCTJEowX2LP2BqYLz3q3JktvXf2pXkiOOzE 23 | p6B4Eq1iDkVwZMXnl2YtmAl+X6/WzChl8gGqCBpH3vn5fJJaCGkgDdk+bW48DW7Y 24 | 5gaRQBi5+MHt39tBquCWIMnNZBU4gcmU7qKEKQsTb47bDN0lAtukixlE0kF6BWlK 25 | WE9gyn6CagsCqiUXObXbf+eEZSqVir2G3l6BFoMtEMze/aiCKm0oHw0LxOXnGiYZ 26 | 4fQRbxC1lfznQgUy286dUV4otp6F01vvpX1FQHKOtw5rDgb7MzVIcbidJ4vEZV8N 27 | hnacRHr2lVz2XTIIM6RUthg/aFzyQkqFOFSDX9HoLPKsEdao7WNq 28 | -----END CERTIFICATE----- -------------------------------------------------------------------------------- /micropython/lib/aws/dab0ac2b5c-certificate.pem.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDWjCCAkKgAwIBAgIVANyZ0m4qukp3J5mmqUqhFU1/iZwnMA0GCSqGSIb3DQEB 3 | CwUAME0xSzBJBgNVBAsMQkFtYXpvbiBXZWIgU2VydmljZXMgTz1BbWF6b24uY29t 4 | IEluYy4gTD1TZWF0dGxlIFNUPVdhc2hpbmd0b24gQz1VUzAeFw0yMTA2MTUxNDM2 5 | MzRaFw00OTEyMzEyMzU5NTlaMB4xHDAaBgNVBAMME0FXUyBJb1QgQ2VydGlmaWNh 6 | dGUwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE61mFvh+LCdHiL10o 7 | vRqEOV34fP+8RXFDrMw0+HAMHm+LOnoI7QnDVsPj3B3jIBQOTIHzO05dYFxy+13U 8 | 5QV2/op9QwrqFGYhAASnbjmhqrBfVwuHFl6b+Rnf+11kUd4Wv2iMSAjO9jJ0gWZe 9 | 4Eq3ClYfJvXELOo33JM31yI+A+xhDXgDv0Tr9RVtA9WWebIURRuQ0VXH94GHcVJM 10 | ugCreWV77psIHLXnVMrF4j2sEiokU8vU6ji/nNjcloOjIZyzsyvQPKVlC7UeApxH 11 | vUckZs/xbIdxbCuyWTJNtBI6kdsmpUEed8wEWbTwbXCY+YGjIuEAa1YG+sQwEF8W 12 | XLiLAgMBAAGjYDBeMB8GA1UdIwQYMBaAFIKI5lqR2kZEF6IgOgLIYWlUquyBMB0G 13 | A1UdDgQWBBQ5FGvxV8OYrVQJEWKG8NjLrlmHYjAMBgNVHRMBAf8EAjAAMA4GA1Ud 14 | DwEB/wQEAwIHgDANBgkqhkiG9w0BAQsFAAOCAQEApdHAqUktXIlnHp3+qhe/PrfM 15 | u9zcqWILHgyouhBoKcntiEz0xc3yaLRVyIcbcLnbVWfNR0yyXoGMOtjNrCHTJ73R 16 | zBYM/A7Q6P3a/bewRZZwgD7hbJSw/X5VXfIhfbqZ0fpok9w1z8Lzas9yG580GCfM 17 | kgTo6PifVNyWUCuJyBl5OJldnb5KwP61pert47X+xLjNdiu3srl9aA5Q41YV1JMO 18 | +t3KWlNdfFK/6w6yUUVQGEKA/MzB+JIhy17Rz2fjaaEpH8tyqO/PFDbrrGMluDc7 19 | DvASjTOhilMNE0qvbRS8wBaP4coQ8BwJLB7Fs4SOg/flaRIy86gdOitn8d2zKA== 20 | -----END CERTIFICATE----- 21 | -------------------------------------------------------------------------------- /micropython/lib/aws/dab0ac2b5c-private.pem.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEAxOtZhb4fiwnR4i9dKL0ahDld+Hz/vEVxQ6zMNPhwDB5vizp6 3 | CO0Jw1bD49wd4yAUDkyB8ztOXWBccvtd1OUFdv6KfUMK6hRmIQAEp245oaqwX1cL 4 | hxZem/kZ3/tdZFHeFr9ojEgIzvYydIFmXuBKtwpWHyb1xCzqN9yTN9ciPgPsYQ14 5 | A79E6/UVbQPVlnmyFEUbkNFVx/eBh3FSTLoAq3lle+6bCBy151TKxeI9rBIqJFPL 6 | 1Oo4v5zY3JaDoyGcs7Mr0DylZQu1HgKcR71HJGbP8WyHcWwrslkyTbQSOpHbJqVB 7 | HnfMBFm08G1wmPmBoyLhAGtWBvrEMBBfFly4iwIDAQABAoIBAQCTG3qaVij8Vo6r 8 | 2VRP/c1UYALago8N2BbARtOa8snJ0+bibQIDrwjvG99lVugg57Lz56XgzjpBuZ32 9 | 69/yDlFhztAoua/qpOiS2I+hgM+e/YObBcz/0u9Et/fjgsYHDr3J4p44xguGiRey 10 | P4T5dbd7PEaQKSvKrP8gUjDMs3PKPRVT+KufEnhItjGD38roCnFPEfAXaww+APK5 11 | ef+CeKys2hnFzpdmKCW7wtefctBW5comIqgit0UAZBaow0mDn8Bpk9kBHs21Z4Gf 12 | fdK8qasdyWZ2pdKF+wnhkjhA0xyymNCibp0JVMgHmCSwCAbPPHFT1SPFYl+fR+0g 13 | I1f/5eIJAoGBAPUkAXySd1lYQNEfP5qV6xTGIhDs8+rA+n9qf+LpVkG7jRrNOgry 14 | ADFkA7Poo2htV3+v7Ws9O+m6YDrEPUQV93FdLq7JUEDTnPcpLFjLhxfWse+jqiyq 15 | NlMA1JkIgprz4Q8AI8fdMkd8SxF/o4kSOmq/I369bge0jC3s5arFYYcvAoGBAM2k 16 | fqM89d9LHVRHERXiIjvk3jqWQOIhTzZw+kb2IFWlYakpzSXbql9ApqRZv2bEDy9a 17 | Ftr2MMU0fVRpJTQzNLiQCfZrqJNKyhxueUW0di/vEQ32QgQ1yZViYOmh6i7tBKQf 18 | Rd9irKu9ou7MTkd4LBpjJTzXV5Jh+XqrWCHMCw1lAoGAZJqyQX2anYAqmahJYhfE 19 | snHd6TRVpjIsm74JXgaE79UYw7cV6wgGLD7vhhYjeefl0PFSboBi5jkm5PdEbXzW 20 | YB/t2K2l0fUGpiivmoFsn8vsmnyAinQQd2AP3/4uG4mbA5z7EjslN1E7PaM0ajyg 21 | 3M40sv7q+R6PqQlf+AWpUOECgYAi+Vj4VkPSyCmizbtgJ2d/HDuXZzp9uwCH18sZ 22 | Cn8RUeE8LfnP+H4PWSV+Si09Qxb6DK5aawh8YaL3GjLTn+nVgdnP6RtwgYXCNB+3 23 | 9Sssl9Ikt78ejC9Okr68vWbguDrT1uXBSjznsIbnaiPeD9ggXMBWiqFAKh2N+Vox 24 | 1UwD6QKBgQDQKMkJy6ddj8zZui5czzKHqhUwFlIuVGZ88C9uH5LHQu8Tku7YRlei 25 | lOrwmM/TtT59dN/tymbmi1q8Eur71XEM/6IF27tIP1jH8C0jLooZbg9fR1AZ8p0m 26 | gEHK4Wj+LRliwzpqGfkXX+WSZDhSodLkB91NCvAmlN2PhmioCiZXzg== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /micropython/lib/aws/dab0ac2b5c-public.pem.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PUBLIC KEY----- 2 | MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAxOtZhb4fiwnR4i9dKL0a 3 | hDld+Hz/vEVxQ6zMNPhwDB5vizp6CO0Jw1bD49wd4yAUDkyB8ztOXWBccvtd1OUF 4 | dv6KfUMK6hRmIQAEp245oaqwX1cLhxZem/kZ3/tdZFHeFr9ojEgIzvYydIFmXuBK 5 | twpWHyb1xCzqN9yTN9ciPgPsYQ14A79E6/UVbQPVlnmyFEUbkNFVx/eBh3FSTLoA 6 | q3lle+6bCBy151TKxeI9rBIqJFPL1Oo4v5zY3JaDoyGcs7Mr0DylZQu1HgKcR71H 7 | JGbP8WyHcWwrslkyTbQSOpHbJqVBHnfMBFm08G1wmPmBoyLhAGtWBvrEMBBfFly4 8 | iwIDAQAB 9 | -----END PUBLIC KEY----- 10 | -------------------------------------------------------------------------------- /micropython/lib/computation.py: -------------------------------------------------------------------------------- 1 | from ulab import numpy as np 2 | import urandom 3 | 4 | def solve_gen_eig_prob(A, B, eps=1e-5): 5 | """ 6 | Solves the generalised eigenvalue problem of the form: 7 | Aw = \lambda*Bw 8 | 9 | Note: can be validated against `scipy.linalg.eig(A, b=B)` 10 | 11 | Ref: 12 | 'Eigenvalue and Generalized Eigenvalue Problems: Tutorial (2019)' 13 | Benyamin Ghojogh and Fakhri Karray and Mark Crowley 14 | arXiv 1903.11240 15 | 16 | """ 17 | Lam_b, Phi_b = np.linalg.eig(B) # eig decomp of B alone 18 | Lam_b = np.eye(len(Lam_b)) * Lam_b # convert to diagonal matrix of eig vals 19 | 20 | Lam_b_sq = replace_nan(Lam_b ** 0.5) + np.eye(len(Lam_b)) * eps 21 | Phi_b_hat = np.dot(Phi_b, np.linalg.inv(Lam_b_sq)) 22 | A_hat = np.dot(np.dot(Phi_b_hat.transpose(), A), Phi_b_hat) 23 | 24 | try: 25 | Lam_a, Phi_a = np.linalg.eig(A_hat) 26 | except ValueError: 27 | # if `ulab` raises a "input matrix asymmetric" error in analytical approach, 28 | # we have to estimate eigen-pair iteratively 29 | Lam_a, Phi_a = solve_eig_qr(A_hat) 30 | 31 | Lam_a = np.eye(len(Lam_a)) * Lam_a 32 | 33 | Lam = Lam_a 34 | Phi = np.dot(Phi_b_hat, Phi_a) 35 | 36 | return np.diag(Lam), Phi 37 | 38 | 39 | def solve_eig_qr(A, iterations=30): 40 | 41 | """ 42 | Use the QR iteration algorithm to iteratively solve for the eigenvectors and eigenvalues 43 | of a matrix A. Note: only guaranteed to recover exactly for symmetric matrices 44 | with real eigenvalues. May work partially for asymmetric matrices (no complex support yet). 45 | 46 | Returns: 47 | `lam`: vector of eigenvalues 48 | `Q_bar`: matrix of eigenvectors (columns) 49 | """ 50 | 51 | Ak = A 52 | Q_bar = np.eye(len(Ak)) 53 | 54 | for _ in range(iterations): 55 | Qk, Rk = np.linalg.qr(Ak) 56 | Ak = np.dot(Rk, Qk) 57 | Q_bar = np.dot(Q_bar, Qk) 58 | 59 | lam = np.diag(Ak) 60 | return lam, Q_bar 61 | 62 | 63 | def power_iteration(A, iterations): 64 | """ 65 | Iterative algo. to find the eigenvector of a matrix A corresponding to the largest 66 | eigenvalue. 67 | 68 | TODO: Establish some measure or heuristic of min number of iterations required 69 | """ 70 | # choose random initial vector to reduce risk of choosing one orthogonal to 71 | # target eigen vector 72 | b_k = np.array([urandom.random() for i in range(len(A))]) 73 | 74 | for _ in range(iterations): 75 | b_k1 = np.dot(A, b_k) 76 | b_k1_norm = np.linalg.norm(b_k1) 77 | # re normalize the vector 78 | b_k = b_k1 / b_k1_norm 79 | 80 | return b_k1_norm, b_k 81 | 82 | 83 | def max_eig(A, iterations, numeric_method="qr"): 84 | """ 85 | Function to return the largest eigenvalue of a matrix and its corresponding eigenvector. 86 | 87 | A must be square but need not be symmetric. Tries to first use uLab `np.linalg.eig` 88 | that is better optimised but requires a symmetric matrix. Failing this, power iteration 89 | algorithm is used. 90 | """ 91 | try: 92 | lam, V = np.linalg.eig(A) 93 | v = V[:, np.argmax(lam)] 94 | except ValueError: 95 | if numeric_method == "power_iteration": 96 | lam, v = power_iteration(A, iterations) 97 | else: 98 | if numeric_method != "qr": 99 | print("Unknown `numeric_method` arg: defaulting to QR solver") 100 | lam, v = solve_eig_qr(A, iterations) 101 | lam = lam[0] # only need first eigen val (largest returned first) 102 | v = v[:, 0] # only first eig vector 103 | 104 | return lam, v 105 | 106 | 107 | def resample(X, factor): 108 | """ 109 | Perform downsampling of signal `X` by an integer `factor`. 110 | """ 111 | idx_rs = np.arange(0, len(X) - 1, factor) 112 | return X[idx_rs] 113 | 114 | 115 | def standardise(X): 116 | axis = np.argmax(X.shape) 117 | minor_shape = np.min(X.shape) 118 | mu = np.mean(X, axis=axis).reshape((minor_shape, 1)) 119 | sigma = np.std(X, axis=axis).reshape((minor_shape, 1)) 120 | return (X - mu) / sigma 121 | 122 | 123 | def cov(X, Y, biased=False): 124 | assert ( 125 | X.shape == Y.shape and len(X.shape) == 1 126 | ), "Expected data vectors of equal length" 127 | assert len(X) > 1, "At least 2 data points are required" 128 | 129 | X = X - np.mean(X) 130 | Y = Y - np.mean(Y) 131 | denom = len(X) if biased else len(X) - 1 132 | 133 | return (np.sum(X * Y)) / denom 134 | 135 | 136 | def corr(X, Y): 137 | assert ( 138 | X.shape == Y.shape and len(X.shape) == 1 139 | ), "Expected data vectors of equal length" 140 | assert len(X) > 1, "At least 2 data points are required" 141 | 142 | return cov(X, Y, biased=True) / (np.std(X) * np.std(Y)) 143 | 144 | 145 | def replace_nan(A, rep=0): 146 | return np.where(np.isfinite(A), A, rep) 147 | 148 | def col_concat(*mats): 149 | """" 150 | Concatenate a variable number of matrices along their 151 | column axis (axis=1 using `numpy` convention). 152 | """ 153 | cols = sum([mat.shape[1] for mat in mats]) 154 | rows = mats[0].shape[0] 155 | out = np.zeros((rows, cols)) 156 | j = 0 157 | for mat in mats: 158 | mat_cols = mat.shape[1] 159 | out[:, j:j+mat_cols] = mat 160 | j += mat_cols 161 | 162 | return out 163 | 164 | def zeros_like(A): 165 | return np.zeros(A.shape) 166 | 167 | def block_diag(X, Y, reverse=False): 168 | if not reverse: 169 | X = np.concatenate((X, zeros_like(X)), axis=1) 170 | Y = np.concatenate((zeros_like(Y), Y), axis=1) 171 | else: 172 | X = np.concatenate((zeros_like(X), X), axis=1) 173 | Y = np.concatenate((Y, zeros_like(Y)), axis=1) 174 | return np.concatenate((X, Y), axis=0) 175 | 176 | def sign(x): 177 | """ 178 | Return the sign of a numerical variable. 179 | """ 180 | x+1 # arb operation to raise an error if non-numeric arg given. 181 | return 1 if x >=0 else -1 -------------------------------------------------------------------------------- /micropython/lib/config.py: -------------------------------------------------------------------------------- 1 | import utime 2 | 3 | BASE_CLK_FREQ = 240000000 # 240 MHz for ESP32 4 | 5 | ADC_SAMPLE_FREQ = 256 # sample freq in Hz 6 | 7 | RECORDING_LEN_SEC = 4 8 | 9 | OVERLAP = 0.8 10 | 11 | DOWNSAMPLED_FREQ = 64 # 64 Hz downsampled to ensure nyquist condition 12 | 13 | PREPROCESSING = True # if true, LP filter and downsample 14 | 15 | STIM_FREQS = [7, 10, 12] # stimulus freqs. in Hz 16 | 17 | DEFAULT_LOG_SESSION = "test-{0}".format(utime.ticks_ms()) 18 | 19 | MODE = "log" 20 | 21 | HTTP_LOG_URL = "http://james-tev.local:5000/" 22 | -------------------------------------------------------------------------------- /micropython/lib/decoding.py: -------------------------------------------------------------------------------- 1 | from ulab import numpy as np 2 | import gc 3 | 4 | from .computation import solve_eig_qr, solve_gen_eig_prob, standardise, col_concat, corr, block_diag, sign 5 | 6 | class SingleChannelMsetCCA(): 7 | """ 8 | Multiset CCA algorithm for SSVEP decoding. 9 | Computes optimised reference signal set based on historical observations 10 | and uses ordinary CCA for final correlation computation given a new test 11 | signal. 12 | Note: this is a 1 channel implementation (Nc=1) 13 | """ 14 | def __init__(self): 15 | self.Ns, self.Nt = None, None 16 | self.Y = None 17 | 18 | def fit(self, X, compress_ref=True): 19 | """ 20 | Expects a training matrix X of shape Nt x Ns. If `compress_ref=True`, the `Nt` components in optimised 21 | reference signal Y will be averaged to form a single reference vector. This can be used for memory 22 | optimisation but will likely degrade performance slightly. 23 | """ 24 | if X.shape[0] > X.shape[1]: 25 | print("Warning: received more trials than samples. This is unusual behaviour: check X") 26 | 27 | R = np.dot(X, X.transpose()) # inter trial covariance matrix 28 | S = np.eye(len(R))*np.diag(R) # intra-trial diag covariance matrix 29 | lam, V = solve_gen_eig_prob((R-S), S) # solve generalised eig problem 30 | w = V[:, np.argmax(lam)] # find eigenvector corresp to largest eigenvalue 31 | Y = np.array([x*w[i] for i, x in enumerate(X)]) # store optimised reference vector Nt x Ns self.Y = Y 32 | 33 | if compress_ref: 34 | self.Y = np.mean(Y, axis=0).reshape((1, max(Y.shape))) # this will average Nt components in Y: Nc x Nt -> 1 x Nt 35 | 36 | def compute_corr(self, X_test): 37 | if not self.is_calibrated: 38 | raise ValueError("Reference matrix Y must be computed using fit before computing corr") 39 | if len(X_test.shape) == 1: 40 | X_test = X_test.reshape((1, len(X_test))) 41 | return CCA.cca_eig(X_test, self.Y)[0] # use ordinary CCA with optimised ref. Y 42 | 43 | @property 44 | def is_calibrated(self): 45 | return self.Y is not None 46 | 47 | class SingleChannelGCCA(): 48 | """ 49 | Generalised canonical component analysis for Nc=1. 50 | Expects the target frequency at `f_ssvep`. `fs` is the sampling rate used and `Nh` the number of harmonics for the harmonic r 51 | Ref: 'Improving SSVEP Identification Accuracy via Generalized Canonical Correlation Analysis' Sun, Chen et al 52 | """ 53 | def __init__(self, f_ssvep, fs, Nh=1, name=None): 54 | self.Ns, self.Nt = None, None 55 | self.Nh = Nh 56 | self.w = None 57 | self.X_bar = None 58 | self.fs = fs 59 | self.f_ssvep = f_ssvep 60 | self.name = name or "gcca_{0}hz".format(f_ssvep) 61 | 62 | def fit(self, X): 63 | """ 64 | Fit against training data. 65 | X should be a matrix of dim (Nt x Ns) 66 | """ 67 | self.Nt, self.Ns = X.shape 68 | 69 | # template signal 70 | X_bar = np.mean(X, axis=0).reshape((1, self.Ns)) 71 | Y = harmonic_reference(self.f_ssvep, self.fs, self.Ns) 72 | 73 | # form concatenated matrices (vectors for Nc=1) 74 | X_c = X.reshape((1, self.Ns*self.Nt)) 75 | 76 | X_bar_c = col_concat(*[X_bar for i in range(self.Nt)]) 77 | X_bar_c = X_bar_c.reshape((1, self.Ns*self.Nt)) 78 | 79 | Y_c = col_concat(*[Y for i in range(self.Nt)]) 80 | 81 | X_comb = col_concat(X_c.T, X_bar_c.T, Y_c.T).T 82 | 83 | D1 = np.dot(X_c, X_c.T) 84 | D2 = np.dot(X_bar_c, X_bar_c.T) 85 | D3 = np.dot(Y_c, Y_c.T) 86 | 87 | D = block_diag(block_diag(D1, D2), D3) 88 | 89 | lam, W_eig = solve_gen_eig_prob(np.dot(X_comb, X_comb.T), D) 90 | 91 | self.w = W_eig[:, np.argmax(lam)] # optimal spatial filter vector with dim (2*Nc + 2*Nh) 92 | self.X_bar = X_bar 93 | 94 | def compute_corr(self, X_test): 95 | """ 96 | Compute output correlation for a test observation with dim. (1 x Ns) 97 | """ 98 | if not self.is_calibrated: 99 | raise ValueError("call .fit(X_train) before performing classification.") 100 | 101 | if len(X_test.shape) == 1: 102 | X_test = X_test.reshape((len(X_test), 1)) 103 | else: 104 | X_test = X_test.T 105 | 106 | w_X = self.w[0:1] 107 | w_X_bar = self.w[1:2] # second weight correspond to Nc (Nc=1) template channels 108 | w_Y = self.w[2:] # final 2*Nh weights correspond to ref sinusoids with harmonics 109 | 110 | # regenerate these instead of storing from the `fit` function since 111 | # computationally cheap to generate but expensive to store in memory 112 | Y = harmonic_reference(self.f_ssvep, self.fs, self.Ns) 113 | 114 | X_test_image = np.dot(X_test, w_X) 115 | rho1 = corr(X_test_image, np.dot(self.X_bar.T, w_X_bar)) 116 | rho2 = corr(X_test_image, np.dot(Y.T, w_Y)) 117 | 118 | return sum([sign(rho_i)*rho_i**2 for rho_i in [rho1, rho2]])/2 119 | 120 | @property 121 | def is_calibrated(self): 122 | return self.w is not None 123 | class CCA: 124 | def __init__(self, f_ssvep, fs, Nh=1): 125 | self.Nh = Nh 126 | self.fs = fs 127 | self.f_ssvep = f_ssvep 128 | 129 | def compute_corr(self, X_test): 130 | Cxx = np.dot( 131 | X_test, X_test.transpose() 132 | ) # precompute data auto correlation matrix 133 | Y = harmonic_reference( 134 | self.f_ssvep, self.fs, np.max(X_test.shape), Nh=self.Nh, standardise_out=False 135 | ) 136 | return self.cca_eig( 137 | X_test, Y, Cxx=Cxx 138 | )[0] # canonical variable matrices. Xc = X^T.W_x 139 | 140 | @staticmethod 141 | def cca_eig(X, Y, Cxx=None, eps=1e-6): 142 | if Cxx is None: 143 | Cxx = np.dot(X, X.transpose()) # auto correlation matrix 144 | Cyy = np.dot(Y, Y.transpose()) 145 | Cxy = np.dot(X, Y.transpose()) # cross correlation matrix 146 | Cyx = np.dot(Y, X.transpose()) # same as Cxy.T 147 | 148 | M1 = np.dot(np.linalg.inv(Cxx + eps), Cxy) # intermediate result 149 | M2 = np.dot(np.linalg.inv(Cyy + eps), Cyx) 150 | 151 | lam, _ = solve_eig_qr(np.dot(M1, M2), 20) 152 | return np.sqrt(lam) 153 | 154 | 155 | def harmonic_reference(f0, fs, Ns, Nh=1, standardise_out=False): 156 | 157 | """ 158 | Generate reference signals for canonical correlation analysis (CCA) 159 | -based steady-state visual evoked potentials (SSVEPs) detection [1, 2]. 160 | function [ y_ref ] = cca_reference(listFreq, fs, Ns, Nh) 161 | Input: 162 | f0 : stimulus frequency 163 | fs : Sampling frequency 164 | Ns : # of samples in trial 165 | Nh : # of harmonics 166 | Output: 167 | X : Generated reference signals with shape (Nf, Ns, 2*Nh) 168 | """ 169 | X = np.zeros((Nh * 2, Ns)) 170 | 171 | for harm_i in range(Nh): 172 | # Sin and Cos 173 | X[2 * harm_i, :] = np.sin( 174 | np.arange(1, Ns + 1) * (1 / fs) * 2 * np.pi * (harm_i + 1) * f0 175 | ) 176 | gc.collect() 177 | X[2 * harm_i + 1, :] = np.cos( 178 | np.arange(1, Ns + 1) * (1 / fs) * 2 * np.pi * (harm_i + 1) * f0 179 | ) 180 | gc.collect() 181 | 182 | # print(micropython.mem_info(1)) 183 | if standardise_out: # zero mean, unit std. dev 184 | return standardise(X) 185 | return X 186 | 187 | class DecoderSSVEP(): 188 | 189 | decoding_algos = ['CCA', 'MsetCCA', 'GCCA'] 190 | 191 | def __init__(self, stim_freqs, fs, algo): 192 | 193 | self.stim_freqs = stim_freqs 194 | self.fs = fs 195 | self.algo = algo 196 | 197 | self.decoder_stack = {} 198 | 199 | for f in self.stim_freqs: 200 | if algo == 'CCA': 201 | decoder_f = CCA(f, self.fs, Nh=1) 202 | elif algo == 'MsetCCA': 203 | decoder_f = SingleChannelMsetCCA() 204 | elif algo == 'GCCA': 205 | decoder_f = SingleChannelGCCA(f, self.fs, Nh=1) 206 | else: 207 | raise ValueError("Invalid algorithm. Must be one of {}".format(self.decoding_algos)) 208 | 209 | self.decoder_stack[f] = decoder_f 210 | 211 | @property 212 | def requires_calibration(self): 213 | return self.algo in ['MsetCCA', 'GCCA'] 214 | 215 | @property 216 | def is_calibrated(self): 217 | return all([d.is_calibrated for d in self.decoder_stack.values()]) 218 | 219 | def calibrate(self, calibration_data_map): 220 | 221 | if not self.requires_calibration: 222 | print("Warning: trying to fit data with an algorithm that doesn't require calibration") 223 | return 224 | 225 | for freq, cal_data in calibration_data_map.items(): 226 | if freq not in self.stim_freqs: 227 | raise ValueError("Invalid stimulus frequency supplied") 228 | self.decoder_stack[freq].fit(cal_data) 229 | 230 | def classify(self, X_test): 231 | result = {} 232 | for f, decoder_f in self.decoder_stack.items(): 233 | if self.requires_calibration and not decoder_f.is_calibrated: 234 | print("Warning: decoder has not been calibrated for {}Hz stimulus frequency".format(f)) 235 | result[f] = np.nan 236 | else: 237 | result[f] = decoder_f.compute_corr(X_test) 238 | return result -------------------------------------------------------------------------------- /micropython/lib/diagnostics.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import micropython 3 | 4 | 5 | def print_memory_diagnostics(): 6 | # see https://docs.micropython.org/en/latest/reference/constrained.html 7 | gc.enable() 8 | 9 | gc.collect() 10 | micropython.mem_info() 11 | print("-----------------------------") 12 | print("Initial free: {} allocated: {}".format(gc.mem_free(), gc.mem_alloc())) 13 | 14 | def func(): 15 | # dummy memory assignment 16 | import urandom 17 | 18 | x = [urandom.random() for i in range(100)] 19 | return 20 | 21 | gc.collect() 22 | print("Func definition: {} allocated: {}".format(gc.mem_free(), gc.mem_alloc())) 23 | func() 24 | print("Func run free: {} allocated: {}".format(gc.mem_free(), gc.mem_alloc())) 25 | gc.collect() 26 | print( 27 | "Garbage collect free: {} allocated: {}".format(gc.mem_free(), gc.mem_alloc()) 28 | ) 29 | print("-----------------------------") 30 | micropython.mem_info(1) 31 | -------------------------------------------------------------------------------- /micropython/lib/logging.py: -------------------------------------------------------------------------------- 1 | from scheduling import ScheduledFunc 2 | import ujson as json 3 | from utils import Enum 4 | import config 5 | 6 | logger_types = Enum(["SERIAL", "MQTT", "HTTP"]) 7 | 8 | 9 | class BaseLogger(ScheduledFunc): 10 | def __init__(self, period_sec, decoded_ref, raw_data_ref, timer_num=1): 11 | super().__init__(timer_num, 1 / period_sec) 12 | self.raw_data = raw_data_ref 13 | self.decoded_data = decoded_ref 14 | 15 | def log(self, *args): 16 | print(self.raw_data) 17 | 18 | def start(self): 19 | self.tim.init(freq=self.freq, callback=self.log) 20 | 21 | 22 | class AbstractWebLogger(BaseLogger): 23 | def __init__( 24 | self, 25 | period_sec, 26 | decoded_ref, 27 | raw_data_ref, 28 | timer_num=1, 29 | server=None, 30 | send_raw=False, 31 | session_id=None, 32 | ): 33 | super().__init__(period_sec, decoded_ref, raw_data_ref, timer_num=timer_num) 34 | 35 | if period_sec < 1: 36 | print("Warning: async web logging at > 1Hz will be unreliable.") 37 | 38 | self.send_raw = send_raw # whether or not to send full raw data 39 | self.server = server 40 | self.session_id = session_id # used to identify logging sessions 41 | 42 | def set_session_id(self, id): 43 | self.session_id = id 44 | 45 | def _prepare_payload(self, payload_id=None): 46 | from lib.networking import pack_payload 47 | 48 | raw_data = self.raw_data if self.send_raw else [] 49 | 50 | return pack_payload( 51 | raw_data, self.decoded_data, user_id=payload_id, session_id=self.session_id 52 | ) 53 | 54 | 55 | class MQTTLogger(AbstractWebLogger): 56 | def __init__( 57 | self, 58 | period_sec, 59 | decoded_ref, 60 | raw_data_ref, 61 | timer_num=1, 62 | server=None, 63 | send_raw=False, 64 | session_id=None, 65 | port=None, 66 | qos=1, 67 | topic=None, 68 | ): 69 | super().__init__( 70 | period_sec, 71 | decoded_ref, 72 | raw_data_ref, 73 | timer_num=timer_num, 74 | server=server, 75 | send_raw=send_raw, 76 | session_id=session_id, 77 | ) 78 | 79 | from lib.networking import setup_mqtt_client, get_default_topic 80 | 81 | self.client = setup_mqtt_client(server=server, port=port) 82 | self.topic = topic or get_default_topic() 83 | self.qos = qos 84 | self.establish_connection() 85 | 86 | def establish_connection(self): 87 | msg = "ESP32 client {0} connected".format(self.client.client_id) 88 | 89 | self.client.connect() 90 | self.client.publish( 91 | topic=self.mqtt_topic, msg=json.dumps({"message": msg}), qos=self.qos 92 | ) 93 | 94 | def _prepare_payload(self): 95 | return super()._prepare_payload(payload_id=self.client.client_id) 96 | 97 | def get_client(self): 98 | return self.client 99 | 100 | def log(self, *args): 101 | payload = self._prepare_payload() 102 | self.client.publish(topic=self.topic, msg=payload, qos=self.qos) 103 | 104 | 105 | class HTTPLogger(AbstractWebLogger): 106 | def __init__( 107 | self, 108 | period_sec, 109 | decoded_ref, 110 | raw_data_ref, 111 | timer_num=1, 112 | server=None, 113 | send_raw=False, 114 | session_id=None, 115 | ): 116 | super().__init__( 117 | period_sec, 118 | decoded_ref, 119 | raw_data_ref, 120 | timer_num=timer_num, 121 | server=server, 122 | send_raw=send_raw, 123 | session_id=session_id, 124 | ) 125 | self.server = server or config.HTTP_LOG_URL 126 | 127 | def log(self, *args): 128 | from lib.requests import MicroWebCli as requests 129 | 130 | payload = self._prepare_payload() 131 | requests.JSONRequest(self.server, payload) 132 | -------------------------------------------------------------------------------- /micropython/lib/networking.py: -------------------------------------------------------------------------------- 1 | from lib.umqtt import MQTTClient 2 | from lib.utils import load_env_vars 3 | 4 | import ujson as json 5 | import urandom 6 | 7 | 8 | def rand_str(l=10): 9 | return "".join([chr(urandom.randint(80, 120)) for i in range(l)]) 10 | 11 | 12 | def default_sub_cb(topic, msg): 13 | print(topic, msg) 14 | 15 | 16 | env_vars = load_env_vars("lib/.env") 17 | 18 | 19 | def setup_mqtt_client(client_id=None, server=None, port=None, callback=None): 20 | 21 | server = server or env_vars.get("MQTT_SERVER") 22 | port = port or env_vars.get("MQTT_PORT") 23 | 24 | client_id = client_id or "eeg-esp32-" + rand_str(l=5) 25 | client = MQTTClient( 26 | client_id=client_id, server=server, port=port, keepalive=6000, ssl=False 27 | ) 28 | 29 | callback = callback or default_sub_cb 30 | client.set_callback(callback) 31 | 32 | return client 33 | 34 | 35 | def get_default_topic(): 36 | return env_vars.get("MQTT_DEFAULT_TOPIC") 37 | 38 | 39 | def pack_payload(raw_data, decoded_data, user_id=None, session_id=None): 40 | import utime as time 41 | 42 | payload = { 43 | "eeg_data": raw_data, 44 | "eeg_data_len": len(raw_data), 45 | "decoded_eeg_data": decoded_data, 46 | "timestamp": time.ticks_us(), 47 | } 48 | if session_id is not None: 49 | payload["session_id"] = session_id 50 | 51 | if user_id is not None: 52 | payload["user_id"] = user_id 53 | 54 | return json.dumps(payload) 55 | -------------------------------------------------------------------------------- /micropython/lib/peripherals.py: -------------------------------------------------------------------------------- 1 | import machine 2 | from lib.scheduling import LedFlasher 3 | from lib.utils import update_buffer 4 | 5 | import utime as time 6 | import gc 7 | 8 | DEFAULT_SPI_PARAMS = { 9 | "spi_num": 2, 10 | "sck": 18, 11 | "mosi": 23, 12 | "miso": 19, 13 | "output_amp_gain": 100, # value between 0-255 controlling gain of output amplifier 14 | } 15 | 16 | DEFAULT_ADC_PARAMS = { 17 | "adc_pin": 33, 18 | "atten": machine.ADC.ATTN_11DB, 19 | "width": machine.ADC.WIDTH_12BIT, 20 | "buffer_size": 256, 21 | } 22 | 23 | DEFAULT_LED_CONFIG = {"green": 26, "red": 13} 24 | 25 | DEFAULT_BTN_CONFIG = {"btn_a": 32, "btn_b": 34} 26 | 27 | 28 | DEFAULT_BTN_CONFIG = {"btn_a": 32, "btn_b": 34} 29 | 30 | del machine # only needed for above imports 31 | gc.collect 32 | 33 | 34 | class PeripheralManager: 35 | def __init__( 36 | self, 37 | adc_params=None, 38 | spi_params=None, 39 | led_config=None, 40 | btn_config=None, 41 | verbose=True, 42 | ): 43 | from machine import Pin 44 | 45 | self.verbose = verbose 46 | self._adc_params = DEFAULT_ADC_PARAMS 47 | self._spi_params = DEFAULT_SPI_PARAMS 48 | self._led_config = DEFAULT_LED_CONFIG 49 | self._btn_config = DEFAULT_BTN_CONFIG 50 | 51 | self._adc_params.update(adc_params or {}) 52 | self._spi_params.update(spi_params or {}) 53 | self._led_config.update(led_config or {}) 54 | self._btn_config.update(btn_config or {}) 55 | 56 | # init LEDs 57 | self.leds = {} 58 | for label, pin in self._led_config.items(): 59 | self.leds[label] = Pin(pin, Pin.OUT) 60 | 61 | # init buttons 62 | self.buttons = {} 63 | for label, pin in self._btn_config.items(): 64 | self.buttons[label] = Pin(pin, Pin.IN) 65 | 66 | self._adc_buffer = [0.0 for i in range(self._adc_params["buffer_size"])] 67 | self._timing_buffer = [] 68 | self._adc_scheduler = None 69 | 70 | gc.collect() 71 | 72 | def init(self): 73 | from machine import Pin, SPI, ADC 74 | 75 | self._adc = ADC( 76 | Pin(self._adc_params["adc_pin"]) 77 | ) # create ADC object on GPIO 33 78 | self._adc.atten(self._adc_params["atten"]) 79 | self._adc.width(self._adc_params["width"]) 80 | if self.verbose: 81 | print("ADC initialised") 82 | 83 | # define spi pins and init 84 | # spi clock frequency = 10 MHz, clock idle status = LOW 85 | # spi clock (SCK) pin = GPIO 18 86 | # spi master output slave input (mosi) pin = GPIO 23 87 | # spi master input slave output (miso) pin = GPIO 19 88 | # local digiPot select pin = GPIO 5 89 | # set up button pin = GPIO 21 90 | get_param = lambda key: Pin(self._spi_params[key]) 91 | temp_spi_params = {key: get_param(key) for key in ["sck", "miso", "mosi"]} 92 | self._spi = SPI( 93 | self._spi_params["spi_num"], 94 | baudrate=10000000, 95 | polarity=0, 96 | phase=0, 97 | **temp_spi_params 98 | ) 99 | if self.verbose: 100 | print("SPI initialised") 101 | 102 | output_gain = max(min(self._spi_params["output_amp_gain"], 255), 0) 103 | self.spi_write(output_gain) 104 | if self.verbose: 105 | print( 106 | "DigiPot set to {0} = gain of {1}".format( 107 | output_gain, 1.745 + (255 - output_gain) / (19.2 - 1.745) 108 | ) 109 | ) 110 | 111 | gc.collect() 112 | 113 | @property 114 | def adc_running(self): 115 | return ( 116 | not self._adc_scheduler is None 117 | ) # best proxy we have: no native micropython method 118 | 119 | def flash_led(self, label, freq, duration_sec): 120 | led = self.get_led(label) 121 | init_state = led.value() 122 | flasher = LedFlasher(0, freq, led) 123 | flasher.run_for_duration(duration_sec) 124 | led.value(init_state) 125 | 126 | def write_led(self, label, val): 127 | led = self.get_led(label) 128 | led.value(val) 129 | 130 | def read_btn(self, label): 131 | btn = self.get_btn(label) 132 | return btn.value() 133 | 134 | def get_led(self, label): 135 | if label not in self.leds: 136 | raise ValueError( 137 | "LED with label {0} not found. Valid options are: {1}".format( 138 | label, self.leds.keys() 139 | ) 140 | ) 141 | return self.leds[label] 142 | 143 | def get_btn(self, label): 144 | if label not in self.buttons: 145 | raise ValueError( 146 | "Button with label {0} not found. Valid options are: {1}".format( 147 | label, self.buttons.keys() 148 | ) 149 | ) 150 | return self.buttons[label] 151 | 152 | def get_adc(self): 153 | return self._adc 154 | 155 | def get_spi(self): 156 | return self._spi 157 | 158 | def read_adc_buffer(self): 159 | return self._adc_buffer 160 | 161 | def read_timing_buffer(self): 162 | return self._timing_buffer 163 | 164 | def spi_write(self, payload): 165 | # data must be in list format, and can be of arbitrary length e.g. [0x00,0x01,x0x02...etc] 166 | # devices can be added ad infinitum, with a unique id e.g. "lo" or "mix" 167 | data = bytearray([17, payload]) 168 | self._spi.write(data) 169 | 170 | def adc_read_to_buff(self, size=-1, record_timing=False): 171 | """Read adc and write to internal buffer. 172 | 173 | Args: 174 | size (int, optional): number samples to take. If -1, buffer will be filled. Defaults to -1. 175 | """ 176 | if record_timing: 177 | delta = time.ticks_us() 178 | n_samples = 20 179 | if len(self._timing_buffer) == 0: 180 | self._timing_buffer = [0 for i in range(n_samples)] 181 | 182 | update_buffer(self._timing_buffer, delta, n_samples) 183 | 184 | buff_size = self._adc_params["buffer_size"] 185 | if size < 0 or size > buff_size: 186 | size = buff_size 187 | data = [self._adc.read() for i in range(size)] 188 | update_buffer(self._adc_buffer, data, buff_size) 189 | 190 | def adc_read(self, size=1): 191 | return [self._adc.read() for i in range(size)] 192 | -------------------------------------------------------------------------------- /micropython/lib/runner.py: -------------------------------------------------------------------------------- 1 | from .peripherals import PeripheralManager 2 | from ulab import numpy as np 3 | from logging import BaseLogger, MQTTLogger, HTTPLogger, logger_types 4 | 5 | import gc 6 | import config 7 | 8 | from micropython import schedule 9 | 10 | 11 | class Runner: 12 | def __init__(self, decoding_algo, buffer_size=256, stimulus_freqs=None) -> None: 13 | if stimulus_freqs is None: 14 | self.stim_freqs = config.STIM_FREQS # assign defaults 15 | 16 | self.decoding_algo = decoding_algo 17 | 18 | self.base_sample_freq = config.ADC_SAMPLE_FREQ 19 | self.downsampled_freq = config.DOWNSAMPLED_FREQ 20 | 21 | self.preprocessing_enabled = config.PREPROCESSING 22 | 23 | if self.preprocessing_enabled: 24 | self.downsampled_freq = config.DOWNSAMPLED_FREQ 25 | 26 | self.sample_counter = 0 27 | self.buffer_size = buffer_size # TODO: update this to be populated dynamically 28 | 29 | self.output_buffer = [0.0 for i in range(self.buffer_size)] 30 | self.decoded_output = {} 31 | 32 | self.is_setup = False 33 | self.is_sampling = False 34 | self.is_logging = False 35 | 36 | def setup(self, spi_params=None, adc_params=None, log_period=5, logger_type=None): 37 | from machine import freq 38 | 39 | freq(config.BASE_CLK_FREQ) # set the CPU frequency 40 | 41 | self._init_decoder() 42 | gc.collect() 43 | 44 | self._init_peripherals(spi_params, adc_params) 45 | gc.collect() 46 | 47 | self._setup_logger(log_period, logger_type) 48 | gc.collect() 49 | 50 | self.is_setup = True 51 | 52 | def run(self): 53 | if not self.is_setup: 54 | raise ValueError("Runner not setup. Call `.setup()` before running.") 55 | 56 | if self.decoder.requires_calibration and not self.is_calibrated: 57 | print("Warning: decoder has not been fully calibrated. Please provide calibration data for\ 58 | each stimulus frequency and call `.calibrate()`") 59 | 60 | self.start_sample_timer() 61 | 62 | if self.logger is not None: 63 | self.start_logger() 64 | 65 | def stop(self): 66 | if self.is_sampling: 67 | self.stop_sample_timer() 68 | 69 | if self.is_logging: 70 | self.stop_logger() 71 | 72 | def preprocess_data(self, signal): 73 | 74 | """Preprocess incoming signal before decoding algorithms. 75 | This involves applying a bandpass filter to isolate the target SSVEP range 76 | and then downsampling the signal to the Nyquist boundary. 77 | 78 | Returns: 79 | [np.ndarray]: filtered and downsampled signal 80 | """ 81 | from lib.signal import sos_filter 82 | 83 | ds_factor = self.downsampling_factor 84 | signal = np.array(signal) - np.mean(signal) # remove DC component 85 | 86 | # downsample filtered signal by only selecting every `ds_factor` sample 87 | return sos_filter(signal, fs=self.base_sample_freq)[::ds_factor] 88 | 89 | def calibrate(self, calibration_data_map): 90 | self.decoder.calibrate(calibration_data_map) 91 | 92 | def decode(self, *args): 93 | """ 94 | Run decoding on current state of output buffer. 95 | 96 | Note that `*args` is specified but not used directly: this allows 97 | this function to be called using `micropython.schedule` which 98 | requires the scheduled func to accept an argument. 99 | """ 100 | data = np.array(self.output_buffer) 101 | data = data.reshape((1, len(data))) # reshape to row vector 102 | gc.collect() 103 | 104 | result = self.decoder.classify(data) 105 | 106 | # note: need to be careful not to change the memory address of this variable using direct 107 | # assignment since the logger depends on this reference. Also would just be inefficient. 108 | self.decoded_output.update( 109 | {freq: round(corr, 5) for freq, corr in result.items()} 110 | ) 111 | gc.collect() 112 | return self.decoded_output 113 | 114 | def sample_callback(self, *args, **kwargs): 115 | from lib.utils import update_buffer 116 | 117 | self.periph_manager.adc_read_to_buff(size=1) 118 | self.sample_counter += 1 119 | 120 | # this will only be true every 1s once buffer fills 121 | if self.sample_counter >= self.buffer_size: 122 | self.periph_manager.write_led("red", 1) 123 | data = self._read_internal_buffer(preprocess=self.preprocessing_enabled) 124 | update_buffer( 125 | self.output_buffer, list(data), self.buffer_size, inplace=True 126 | ) 127 | self.sample_counter = 0 128 | self.periph_manager.write_led("red", 0) 129 | 130 | # TODO: workout how to run decoding in another handler as 131 | # this could take a non-negligible amount of time which 132 | # would disrupt consistency of sampling freq. For now, 133 | # we can schedule this function to run 'soon' while allowing other 134 | # ISRs to interrupt it if need be. 135 | try: 136 | schedule(self.decode, None) 137 | except RuntimeError: 138 | # if schedule queue is full, run now 139 | self.decode() 140 | 141 | def read_output_buffer(self): 142 | return self.output_buffer 143 | 144 | def start_logger(self): 145 | if self.logger is not None: 146 | self.logger.start() 147 | self.is_logging = True 148 | 149 | def stop_logger(self): 150 | if self.logger is not None: 151 | self.logger.stop() 152 | self.is_logging = False 153 | 154 | def start_sample_timer(self): 155 | from machine import Timer 156 | 157 | self.sample_timer = Timer(0) 158 | self.sample_timer.init( 159 | freq=self.base_sample_freq, callback=self.sample_callback 160 | ) 161 | self.is_sampling = True 162 | 163 | def stop_sample_timer(self): 164 | if self.sample_timer is not None: 165 | self.sample_timer.deinit() 166 | self.is_sampling = False 167 | 168 | @property 169 | def downsampling_factor(self): 170 | return self.base_sample_freq // self.downsampled_freq 171 | 172 | @property 173 | def is_calibrated(self): 174 | return self.decoder.is_calibrated 175 | 176 | def _read_internal_buffer(self, preprocess=False): 177 | data = self.periph_manager.read_adc_buffer() 178 | if preprocess and len(data) > 1: 179 | data = self.preprocess_data(data) 180 | return data 181 | 182 | def _init_peripherals(self, spi_params, adc_params): 183 | 184 | self.periph_manager = PeripheralManager( 185 | spi_params=spi_params, adc_params=adc_params 186 | ) 187 | self.periph_manager.init() 188 | 189 | def _init_decoder(self): 190 | from lib.decoding import DecoderSSVEP 191 | 192 | # note: downsampled_freq is same as base sampling freq if 193 | # preprocessing is disabled 194 | self.decoder = DecoderSSVEP(self.stim_freqs, self.downsampled_freq, self.decoding_algo) 195 | 196 | def _setup_logger(self, log_period, logger_type): 197 | if logger_type is not None: 198 | if logger_type != logger_types.SERIAL: 199 | print( 200 | "Warning: only the `SERIAL` logger type is available offline. Defaulting to this." 201 | ) 202 | self.logger = BaseLogger( 203 | log_period, self.decoded_output, self.output_buffer 204 | ) 205 | else: 206 | self.logger = None 207 | 208 | 209 | class OnlineRunner(Runner): 210 | def setup( 211 | self, 212 | spi_params=None, 213 | adc_params=None, 214 | log_period=5, 215 | logger_type=None, 216 | **logger_params 217 | ): 218 | super().setup(spi_params=spi_params, adc_params=adc_params) 219 | gc.collect() 220 | 221 | self.configure_wifi(env_path="lib/.env") 222 | gc.collect() 223 | 224 | self._setup_logger(log_period, logger_type, **logger_params) 225 | 226 | def _setup_logger(self, log_period, logger_type, **logger_params): 227 | if logger_type is not None: 228 | base_logger_args = [log_period, self.decoded_output, self.output_buffer] 229 | if logger_type == logger_types.MQTT: 230 | self.logger = MQTTLogger(*base_logger_args, **logger_params) 231 | elif logger_type == logger_types.HTTP: 232 | self.logger = HTTPLogger(*base_logger_args, **logger_params) 233 | else: 234 | self.logger = BaseLogger(*base_logger_args) 235 | else: 236 | self.logger = None 237 | 238 | @staticmethod 239 | def configure_wifi(env_path=".env"): 240 | 241 | from lib.utils import connect_wifi, load_env_vars 242 | 243 | env_vars = load_env_vars(env_path) 244 | ssid = env_vars.get("WIFI_SSID") 245 | password = env_vars.get("WIFI_PASSWORD") 246 | connect_wifi(ssid, password) 247 | -------------------------------------------------------------------------------- /micropython/lib/scheduling.py: -------------------------------------------------------------------------------- 1 | from machine import Pin, Timer 2 | import micropython 3 | import ujson as json 4 | import utime as time 5 | 6 | micropython.alloc_emergency_exception_buf(100) 7 | 8 | 9 | class ScheduledFunc: 10 | def __init__(self, timer_num, freq): 11 | self.freq = freq 12 | self.tim = Timer(timer_num) 13 | 14 | def start(self, callback=None): 15 | callback = ( 16 | callback or self.cb 17 | ) # use callback if supplied, else default class callback 18 | self.tim.init( 19 | freq=self.freq, callback=callback 20 | ) # !! freq kwarg was only added in mpy v1.16 21 | 22 | def run_for_duration(self, duration_sec): 23 | self.start() 24 | t0 = time.time() 25 | while time.time() - t0 < duration_sec: 26 | pass 27 | self.stop() 28 | 29 | def stop(self): 30 | self.tim.deinit() 31 | 32 | def cb(self, timer, *args): 33 | pass 34 | 35 | 36 | class LedFlasher(ScheduledFunc): 37 | def __init__(self, timer_num, freq, led): 38 | super().__init__(timer_num, freq * 2) # so that LED comes on at f=freq 39 | self.led = led 40 | 41 | def cb(self, timer, *args): 42 | self.led.value(not self.led.value()) 43 | 44 | 45 | class WsDataScheduler(ScheduledFunc): 46 | def __init__(self, freq, ws_server, led_pin=5, timer_num=0): 47 | super().__init__(timer_num, freq) 48 | self.led = Pin(led_pin, Pin.OUT) 49 | self.ws_server = ws_server 50 | 51 | def start(self): 52 | super().start() 53 | self.ws_server.start() 54 | 55 | def stop(self): 56 | super().stop() 57 | self.ws_server.stop() 58 | 59 | def send_data(self, *args): 60 | self.ws_server.process_all() 61 | data = None 62 | self.ws_server.broadcast(json.dumps(data)) 63 | 64 | # self.ws_server.broadcast(data.tobytes()) 65 | 66 | def cb(self, timer, *args): 67 | self.led.value(not self.led.value()) 68 | micropython.schedule( 69 | self.send_data, self 70 | ) # see https://docs.micropython.org/en/latest/library/micropython.html?highlight=schedule#micropython.schedule 71 | -------------------------------------------------------------------------------- /micropython/lib/signal.py: -------------------------------------------------------------------------------- 1 | from ulab import numpy as np 2 | from ulab import scipy as spy 3 | 4 | """ 5 | The digital filter coefficients below were designed using Scipy to the same 6 | specification: 7 | 8 | Elliptical 10th order bandpass filter with corner frequencies at (4, 28)Hz, 9 | 0.2dB passband ripple and 80dB stopband atten 10 | """ 11 | 12 | SOS_SSVEP_BANDPASS_256HZ = np.array( 13 | [ 14 | [ 15 | 5.18442631e-04, 16 | 5.91022291e-04, 17 | 5.18442631e-04, 18 | 1.00000000e00, 19 | -1.58700686e00, 20 | 6.47826110e-01, 21 | ], 22 | [ 23 | 1.00000000e00, 24 | -6.71721317e-01, 25 | 1.00000000e00, 26 | 1.00000000e00, 27 | -1.56164716e00, 28 | 7.42956116e-01, 29 | ], 30 | [ 31 | 1.00000000e00, 32 | -1.19862825e00, 33 | 1.00000000e00, 34 | 1.00000000e00, 35 | -1.53434369e00, 36 | 8.53024717e-01, 37 | ], 38 | [ 39 | 1.00000000e00, 40 | -1.36462221e00, 41 | 1.00000000e00, 42 | 1.00000000e00, 43 | -1.52074686e00, 44 | 9.31086238e-01, 45 | ], 46 | [ 47 | 1.00000000e00, 48 | -1.41821305e00, 49 | 1.00000000e00, 50 | 1.00000000e00, 51 | -1.52570664e00, 52 | 9.80264626e-01, 53 | ], 54 | ] 55 | ) 56 | 57 | SOS_SSVEP_BANDPASS_128HZ = np.array( 58 | [ 59 | [0.00489814, 0.00882672, 0.00489814, 1.0, -1.12754282, 0.37507747], 60 | [1.0, 0.89364345, 1.0, 1.0, -0.86464138, 0.5663009], 61 | [1.0, 0.27438961, 1.0, 1.0, -0.59631233, 0.76500326], 62 | [1.0, -0.00656791, 1.0, 1.0, -0.4363727, 0.89332053], 63 | [1.0, -0.11037337, 1.0, 1.0, -0.37229848, 0.96976145], 64 | ] 65 | ) 66 | 67 | 68 | def sos_filter(x, sos_coeffs=None, fs=256): 69 | if sos_coeffs is None: 70 | if fs == 256: 71 | sos_coeffs = SOS_SSVEP_BANDPASS_256HZ 72 | elif fs == 128: 73 | sos_coeffs = SOS_SSVEP_BANDPASS_128HZ 74 | else: 75 | raise ValueError( 76 | "Unepexcted sampling frequency. Only have SOS filter weights for fs = 64Hz or 256Hz" 77 | ) 78 | 79 | return spy.signal.sosfilt(sos_coeffs, x) 80 | -------------------------------------------------------------------------------- /micropython/lib/synthetic.py: -------------------------------------------------------------------------------- 1 | import urandom 2 | from ulab import numpy as np 3 | 4 | def synth_x(f, Ns, noise_power=0.5, fs=256): 5 | """ 6 | generate a synthetic signal vector 7 | 8 | args: 9 | Ns [int]: number of samples (time samples) 10 | noise_power [float]: variance of WGN noise distribution 11 | """ 12 | t = np.arange(0, Ns/fs, 1/fs) 13 | return np.sin(t*2*np.pi*f)*(1+urandom.random()*noise_power) 14 | 15 | def synth_X(f, Nc, Ns, Nt=1, noise_power=0.5, fs=256, f_std=0.02): 16 | """ 17 | Generate a matrix of several variations of the same target signal. This is used 18 | to simulate the measurement of a common signal over multiple EEG channels 19 | that have different SNR characteristics. 20 | 21 | args: 22 | f [float]: target frequency of synthetic signal (Hz) 23 | Nc [int]: number of channels 24 | Ns [int]: number of samples (time samples) 25 | Ns [int]: number of iid trials 26 | noise_power [float]: variance of WGN noise distribution 27 | fs [float]: sampling frequency (Hz) 28 | f_std [float]: standard dev. of freq. in generated signal across channels to simulate interference from other frequency components over different channels 29 | """ 30 | def _synth(): 31 | X = [] 32 | for i in range(Nc): # simulate noisy sinusoids with varying SNR across Nc channels 33 | f_i = f*(1+urandom.random()*f_std) 34 | x = synth_x(f_i, Ns, noise_power=noise_power, fs=fs) 35 | 36 | X.append(x) 37 | 38 | return np.array(X) 39 | 40 | if Nt <= 1: 41 | return _synth() 42 | else: 43 | trials = [] 44 | for i in range(Nt): 45 | trials.append(_synth().flatten()) 46 | 47 | return np.array(trials) -------------------------------------------------------------------------------- /micropython/lib/umqtt.py: -------------------------------------------------------------------------------- 1 | import usocket as socket 2 | import ustruct as struct 3 | from ubinascii import hexlify 4 | 5 | 6 | class MQTTException(Exception): 7 | pass 8 | 9 | 10 | class MQTTClient: 11 | def __init__( 12 | self, 13 | client_id, 14 | server, 15 | port=0, 16 | user=None, 17 | password=None, 18 | keepalive=0, 19 | ssl=False, 20 | ssl_params={}, 21 | ): 22 | if port == 0: 23 | port = 8883 if ssl else 1883 24 | self.client_id = client_id 25 | self.sock = None 26 | self.server = server 27 | self.port = port 28 | self.ssl = ssl 29 | self.ssl_params = ssl_params 30 | self.pid = 0 31 | self.cb = None 32 | self.user = user 33 | self.pswd = password 34 | self.keepalive = keepalive 35 | self.lw_topic = None 36 | self.lw_msg = None 37 | self.lw_qos = 0 38 | self.lw_retain = False 39 | 40 | def _send_str(self, s): 41 | self.sock.write(struct.pack("!H", len(s))) 42 | self.sock.write(s) 43 | 44 | def _recv_len(self): 45 | n = 0 46 | sh = 0 47 | while 1: 48 | b = self.sock.read(1)[0] 49 | n |= (b & 0x7F) << sh 50 | if not b & 0x80: 51 | return n 52 | sh += 7 53 | 54 | def set_callback(self, f): 55 | self.cb = f 56 | 57 | def set_last_will(self, topic, msg, retain=False, qos=0): 58 | assert 0 <= qos <= 2 59 | assert topic 60 | self.lw_topic = topic 61 | self.lw_msg = msg 62 | self.lw_qos = qos 63 | self.lw_retain = retain 64 | 65 | def connect(self, clean_session=True): 66 | self.sock = socket.socket() 67 | addr = socket.getaddrinfo(self.server, self.port)[0][-1] 68 | print("Attemptint to connect to socket addr: ", addr) 69 | self.sock.connect(addr) 70 | if self.ssl: 71 | import ussl 72 | 73 | self.sock = ussl.wrap_socket(self.sock, **self.ssl_params) 74 | premsg = bytearray(b"\x10\0\0\0\0\0") 75 | msg = bytearray(b"\x04MQTT\x04\x02\0\0") 76 | 77 | sz = 10 + 2 + len(self.client_id) 78 | msg[6] = clean_session << 1 79 | if self.user is not None: 80 | sz += 2 + len(self.user) + 2 + len(self.pswd) 81 | msg[6] |= 0xC0 82 | if self.keepalive: 83 | assert self.keepalive < 65536 84 | msg[7] |= self.keepalive >> 8 85 | msg[8] |= self.keepalive & 0x00FF 86 | if self.lw_topic: 87 | sz += 2 + len(self.lw_topic) + 2 + len(self.lw_msg) 88 | msg[6] |= 0x4 | (self.lw_qos & 0x1) << 3 | (self.lw_qos & 0x2) << 3 89 | msg[6] |= self.lw_retain << 5 90 | 91 | i = 1 92 | while sz > 0x7F: 93 | premsg[i] = (sz & 0x7F) | 0x80 94 | sz >>= 7 95 | i += 1 96 | premsg[i] = sz 97 | 98 | self.sock.write(premsg, i + 2) 99 | self.sock.write(msg) 100 | # print(hex(len(msg)), hexlify(msg, ":")) 101 | self._send_str(self.client_id) 102 | if self.lw_topic: 103 | self._send_str(self.lw_topic) 104 | self._send_str(self.lw_msg) 105 | if self.user is not None: 106 | self._send_str(self.user) 107 | self._send_str(self.pswd) 108 | resp = self.sock.read(4) 109 | print("resp: ", resp) 110 | 111 | assert resp[0] == 0x20 and resp[1] == 0x02 112 | if resp[3] != 0: 113 | raise MQTTException(resp[3]) 114 | return resp[2] & 1 115 | 116 | def disconnect(self): 117 | self.sock.write(b"\xe0\0") 118 | self.sock.close() 119 | 120 | def ping(self): 121 | self.sock.write(b"\xc0\0") 122 | 123 | def publish(self, topic, msg, retain=False, qos=0): 124 | pkt = bytearray(b"\x30\0\0\0") 125 | pkt[0] |= qos << 1 | retain 126 | sz = 2 + len(topic) + len(msg) 127 | if qos > 0: 128 | sz += 2 129 | assert sz < 2097152 130 | i = 1 131 | while sz > 0x7F: 132 | pkt[i] = (sz & 0x7F) | 0x80 133 | sz >>= 7 134 | i += 1 135 | pkt[i] = sz 136 | print(hex(len(pkt)), hexlify(pkt, ":")) 137 | self.sock.write(pkt, i + 1) 138 | self._send_str(topic) 139 | if qos > 0: 140 | self.pid += 1 141 | pid = self.pid 142 | struct.pack_into("!H", pkt, 0, pid) 143 | self.sock.write(pkt, 2) 144 | self.sock.write(msg) 145 | if qos == 1: 146 | while 1: 147 | op = self.wait_msg() 148 | if op == 0x40: 149 | sz = self.sock.read(1) 150 | assert sz == b"\x02" 151 | rcv_pid = self.sock.read(2) 152 | rcv_pid = rcv_pid[0] << 8 | rcv_pid[1] 153 | if pid == rcv_pid: 154 | return 155 | elif qos == 2: 156 | assert 0 157 | 158 | def subscribe(self, topic, qos=0): 159 | assert self.cb is not None, "Subscribe callback is not set" 160 | pkt = bytearray(b"\x82\0\0\0") 161 | self.pid += 1 162 | struct.pack_into("!BH", pkt, 1, 2 + 2 + len(topic) + 1, self.pid) 163 | # print(hex(len(pkt)), hexlify(pkt, ":")) 164 | self.sock.write(pkt) 165 | self._send_str(topic) 166 | self.sock.write(qos.to_bytes(1, "little")) 167 | while 1: 168 | op = self.wait_msg() 169 | if op == 0x90: 170 | resp = self.sock.read(4) 171 | # print(resp) 172 | assert resp[1] == pkt[2] and resp[2] == pkt[3] 173 | if resp[3] == 0x80: 174 | raise MQTTException(resp[3]) 175 | return 176 | 177 | # Wait for a single incoming MQTT message and process it. 178 | # Subscribed messages are delivered to a callback previously 179 | # set by .set_callback() method. Other (internal) MQTT 180 | # messages processed internally. 181 | def wait_msg(self): 182 | res = self.sock.read(1) 183 | self.sock.setblocking(True) 184 | if res is None: 185 | return None 186 | if res == b"": 187 | raise OSError(-1) 188 | if res == b"\xd0": # PINGRESP 189 | sz = self.sock.read(1)[0] 190 | assert sz == 0 191 | return None 192 | op = res[0] 193 | if op & 0xF0 != 0x30: 194 | return op 195 | sz = self._recv_len() 196 | topic_len = self.sock.read(2) 197 | topic_len = (topic_len[0] << 8) | topic_len[1] 198 | topic = self.sock.read(topic_len) 199 | sz -= topic_len + 2 200 | if op & 6: 201 | pid = self.sock.read(2) 202 | pid = pid[0] << 8 | pid[1] 203 | sz -= 2 204 | msg = self.sock.read(sz) 205 | self.cb(topic, msg) 206 | if op & 6 == 2: 207 | pkt = bytearray(b"\x40\x02\0\0") 208 | struct.pack_into("!H", pkt, 2, pid) 209 | self.sock.write(pkt) 210 | elif op & 6 == 4: 211 | assert 0 212 | 213 | # Checks whether a pending message from server is available. 214 | # If not, returns immediately with None. Otherwise, does 215 | # the same processing as wait_msg. 216 | def check_msg(self): 217 | self.sock.setblocking(False) 218 | return self.wait_msg() 219 | -------------------------------------------------------------------------------- /micropython/lib/utils.py: -------------------------------------------------------------------------------- 1 | class Enum(set): 2 | """ 3 | Class to simulate the behaviour of an enum type 4 | """ 5 | 6 | def __getattr__(self, name): 7 | if name in self: 8 | return name 9 | raise AttributeError 10 | 11 | 12 | def delay_ms(t): 13 | import utime 14 | 15 | t0 = utime.time() 16 | while (utime.time() - t0) * 1000 < t: 17 | pass # TODO: investigate if this will actually free core during delay 18 | 19 | 20 | def update_buffer(buf, el, max_size, inplace=True): 21 | if max_size != len(buf): 22 | inplace = False 23 | 24 | if type(el) in [float, int]: 25 | el = [el] 26 | el = el[-max_size:] 27 | tmp = (buf[-(max_size - len(el)) :] + el)[-max_size:] 28 | if inplace: 29 | for i, el in enumerate(tmp): 30 | buf[i] = el 31 | return None 32 | return tmp 33 | 34 | 35 | def connect_wifi(ssid, password): 36 | import network 37 | 38 | wlan = network.WLAN(network.STA_IF) 39 | wlan.active(True) 40 | if not wlan.isconnected(): 41 | print("connecting to network...") 42 | wlan.connect(ssid, password) 43 | while not wlan.isconnected(): # okay that this is blocking 44 | pass 45 | print("network config:", wlan.ifconfig()) 46 | return wlan 47 | 48 | 49 | def load_env_vars(path): 50 | import ure as re 51 | 52 | envre = re.compile(r"""^([^\s=]+)=(?:[\s"']*)(.+?)(?:[\s"']*)$""") 53 | result = {} 54 | with open(path) as ins: 55 | for line in ins: 56 | match = envre.match(line) 57 | if match is not None: 58 | result[match.group(1)] = match.group(2).replace("\n", "") 59 | return result 60 | 61 | 62 | def write_json(filename, data): 63 | import ujson as json 64 | 65 | with open(filename, "w") as f: 66 | json.dump(data, f) 67 | 68 | 69 | def read_json(filename): 70 | import ujson as json 71 | 72 | with open(filename) as f: 73 | return json.load(f) 74 | -------------------------------------------------------------------------------- /micropython/lib/websocket/Multiserver/ws_multiserver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import websocket_helper 3 | from time import sleep 4 | 5 | from ws_server import WebSocketServer 6 | from ws_connection import WebSocketConnection 7 | 8 | 9 | class WebSocketMultiServer(WebSocketServer): 10 | http_codes = { 11 | 200: "OK", 12 | 404: "Not Found", 13 | 500: "Internal Server Error", 14 | 503: "Service Unavailable", 15 | } 16 | 17 | mime_types = { 18 | "jpg": "image/jpeg", 19 | "jpeg": "image/jpeg", 20 | "png": "image/png", 21 | "gif": "image/gif", 22 | "html": "text/html", 23 | "htm": "text/html", 24 | "css": "text/css", 25 | "js": "application/javascript", 26 | } 27 | 28 | def __init__(self, index_page, max_connections=1): 29 | super().__init__(index_page, max_connections) 30 | dir_idx = index_page.rfind("/") 31 | self._web_dir = index_page[0:dir_idx] if dir_idx > 0 else "/" 32 | 33 | def _accept_conn(self, listen_sock): 34 | cl, remote_addr = self._listen_s.accept() 35 | print("Client connection from:", remote_addr) 36 | 37 | if len(self._clients) >= self._max_connections: 38 | # Maximum connections limit reached 39 | cl.setblocking(True) 40 | self._generate_static_page(cl, 503, "503 Too Many Connections") 41 | return 42 | 43 | requested_file = None 44 | data = cl.recv(64).decode() 45 | if ( 46 | data 47 | and "Upgrade: websocket" not in data.split("\r\n") 48 | and "GET" == data.split(" ")[0] 49 | ): 50 | # data should looks like GET /index.html HTTP/1.1\r\nHost: 19" 51 | # requested file is on second position in data, ignore all get parameters after question mark 52 | requested_file = data.split(" ")[1].split("?")[0] 53 | requested_file = ( 54 | self._page if requested_file in [None, "/"] else requested_file 55 | ) 56 | 57 | try: 58 | websocket_helper.server_handshake(cl) 59 | self._clients.append( 60 | self._make_client( 61 | WebSocketConnection(remote_addr, cl, self.remove_connection) 62 | ) 63 | ) 64 | except OSError: 65 | if requested_file: 66 | cl.setblocking(True) 67 | self._serve_file(requested_file, cl) 68 | else: 69 | self._generate_static_page(cl, 500, "500 Internal Server Error [2]") 70 | 71 | def _serve_file(self, requested_file, c_socket): 72 | print("### Serving file: {}".format(requested_file)) 73 | try: 74 | # check if file exists in web directory 75 | path = requested_file.split("/") 76 | filename = path[-1] 77 | subdir = "/" + "/".join(path[1:-1]) if len(path) > 2 else "" 78 | 79 | if filename not in os.listdir(self._web_dir + subdir): 80 | self._generate_static_page(c_socket, 404, "404 Not Found") 81 | return 82 | 83 | # Create path based on web root directory 84 | file_path = self._web_dir + requested_file 85 | length = os.stat(file_path)[6] 86 | c_socket.sendall(self._generate_headers(200, file_path, length)) 87 | # Send file by chunks to prevent large memory consumption 88 | chunk_size = 1024 89 | with open(file_path, "rb") as f: 90 | while True: 91 | data = f.read(chunk_size) 92 | c_socket.sendall(data) 93 | if len(data) < chunk_size: 94 | break 95 | sleep(0.1) 96 | c_socket.close() 97 | except OSError: 98 | self._generate_static_page(c_socket, 500, "500 Internal Server Error [2]") 99 | 100 | @staticmethod 101 | def _generate_headers(code, filename=None, length=None): 102 | content_type = "text/html" 103 | 104 | if filename: 105 | ext = filename.split(".")[1] 106 | if ext in WebSocketMultiServer.mime_types: 107 | content_type = WebSocketMultiServer.mime_types[ext] 108 | 109 | # Close connection after completing the request 110 | return ( 111 | "HTTP/1.1 {} {}\n" 112 | "Content-Type: {}\n" 113 | "Content-Length: {}\n" 114 | "Server: ESPServer\n" 115 | "Connection: close\n\n".format( 116 | code, WebSocketMultiServer.http_codes[code], content_type, length 117 | ) 118 | ) 119 | 120 | @staticmethod 121 | def _generate_static_page(sock, code, message): 122 | sock.sendall(WebSocketMultiServer._generate_headers(code)) 123 | sock.sendall("

" + message + "

") 124 | sleep(0.1) 125 | sock.close() 126 | -------------------------------------------------------------------------------- /micropython/lib/websocket/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesTev/EEG-decoding/c4054dd9d1eac857aedd487a34f177c97d95c0af/micropython/lib/websocket/__init__.py -------------------------------------------------------------------------------- /micropython/lib/websocket/test.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | WebSocket Client 6 | 7 | 8 | 9 | Received:
10 |
11 | (Received content) 12 |
13 | 14 | 25 | 26 | -------------------------------------------------------------------------------- /micropython/lib/websocket/websocket_demo.py: -------------------------------------------------------------------------------- 1 | from websocket.ws_connection import ClientClosedError 2 | from websocket.ws_server import WebSocketServer, WebSocketClient 3 | 4 | 5 | class TestClient(WebSocketClient): 6 | def __init__(self, conn): 7 | super().__init__(conn) 8 | 9 | def process(self): 10 | try: 11 | msg = self.connection.read() 12 | if not msg: 13 | return 14 | msg = msg.decode("utf-8") 15 | print(msg) 16 | self.connection.write(msg) 17 | except ClientClosedError: 18 | self.connection.close() 19 | 20 | 21 | class TestServer(WebSocketServer): 22 | def __init__(self): 23 | super().__init__("test.html", 2) 24 | 25 | def _make_client(self, conn): 26 | return TestClient(conn) 27 | 28 | 29 | # server = TestServer() 30 | # server.start() 31 | # try: 32 | # while True: 33 | # server.process_all() 34 | # except KeyboardInterrupt: 35 | # pass 36 | # server.stop() 37 | -------------------------------------------------------------------------------- /micropython/lib/websocket/ws_connection.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from uwebsocket import websocket 3 | import uselect 4 | 5 | 6 | class ClientClosedError(Exception): 7 | pass 8 | 9 | 10 | class WebSocketConnection: 11 | def __init__(self, addr, s, close_callback): 12 | self.client_close = False 13 | self._need_check = False 14 | 15 | self.address = addr 16 | self.socket = s 17 | self.ws = websocket(s, True) 18 | self.poll = uselect.poll() 19 | self.close_callback = close_callback 20 | 21 | self.socket.setblocking(False) 22 | self.poll.register(self.socket, uselect.POLLIN) 23 | 24 | def read(self): 25 | poll_events = self.poll.poll(0) 26 | 27 | if not poll_events: 28 | return 29 | 30 | # Check the flag for connection hung up 31 | if poll_events[0][1] & uselect.POLLHUP: 32 | self.client_close = True 33 | 34 | msg_bytes = None 35 | try: 36 | msg_bytes = self.ws.read() 37 | except OSError: 38 | self.client_close = True 39 | 40 | # If no bytes => connection closed. See the link below. 41 | # http://stefan.buettcher.org/cs/conn_closed.html 42 | if not msg_bytes or self.client_close: 43 | raise ClientClosedError() 44 | 45 | return msg_bytes 46 | 47 | def write(self, msg): 48 | try: 49 | self.ws.write(msg) 50 | except OSError: 51 | self.client_close = True 52 | 53 | def is_closed(self): 54 | return self.socket is None 55 | 56 | def close(self): 57 | print("Closing connection.") 58 | self.poll.unregister(self.socket) 59 | self.socket.close() 60 | self.socket = None 61 | self.ws = None 62 | if self.close_callback: 63 | self.close_callback(self) 64 | -------------------------------------------------------------------------------- /micropython/lib/websocket/ws_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import network 4 | import websocket_helper 5 | import uselect 6 | from time import sleep 7 | from websocket.ws_connection import WebSocketConnection, ClientClosedError 8 | 9 | 10 | class WebSocketClient: 11 | def __init__(self, conn): 12 | self.connection = conn 13 | 14 | def process(self): 15 | pass 16 | 17 | 18 | class WebSocketServer: 19 | def __init__(self, page, max_connections=1): 20 | self._listen_s = None 21 | self._listen_poll = None 22 | self._clients = [] 23 | self._max_connections = max_connections 24 | self._page = page 25 | 26 | def _setup_conn(self, port, attempt_wifi_conn=True): 27 | from lib.utils import connect_wifi 28 | 29 | self._listen_s = socket.socket() 30 | self._listen_s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 31 | self._listen_poll = uselect.poll() 32 | 33 | ai = socket.getaddrinfo("0.0.0.0", port) 34 | addr = ai[0][4] 35 | 36 | self._listen_s.bind(addr) 37 | self._listen_s.listen(1) 38 | self._listen_poll.register(self._listen_s) 39 | 40 | connect_wifi() 41 | 42 | for i in (network.AP_IF, network.STA_IF): 43 | iface = network.WLAN(i) 44 | if iface.active(): 45 | print("WebSocket started on ws://%s:%d" % (iface.ifconfig()[0], port)) 46 | return iface 47 | 48 | def _check_new_connections(self, accept_handler): 49 | poll_events = self._listen_poll.poll(0) 50 | if not poll_events: 51 | return 52 | 53 | if poll_events[0][1] & uselect.POLLIN: 54 | accept_handler() 55 | 56 | def _accept_conn(self): 57 | cl, remote_addr = self._listen_s.accept() 58 | print("Client connection from:", remote_addr) 59 | 60 | if len(self._clients) >= self._max_connections: 61 | # Maximum connections limit reached 62 | cl.setblocking(True) 63 | cl.sendall("HTTP/1.1 503 Too many connections\n\n") 64 | cl.sendall("\n") 65 | # TODO: Make sure the data is sent before closing 66 | sleep(0.1) 67 | cl.close() 68 | return 69 | 70 | try: 71 | websocket_helper.server_handshake(cl) 72 | except OSError: 73 | # Not a websocket connection, serve webpage 74 | self._serve_page(cl) 75 | return 76 | 77 | self._clients.append( 78 | self._make_client( 79 | WebSocketConnection(remote_addr, cl, self.remove_connection) 80 | ) 81 | ) 82 | 83 | def _make_client(self, conn): 84 | return WebSocketClient(conn) 85 | 86 | def _serve_page(self, sock): 87 | try: 88 | sock.sendall( 89 | "HTTP/1.1 200 OK\nConnection: close\nServer: WebSocket Server\nContent-Type: text/html\n" 90 | ) 91 | length = os.stat(self._page)[6] 92 | sock.sendall("Content-Length: {}\n\n".format(length)) 93 | # Process page by lines to avoid large strings 94 | with open(self._page, "r") as f: 95 | for line in f: 96 | sock.sendall(line) 97 | except OSError: 98 | # Error while serving webpage 99 | pass 100 | sock.close() 101 | 102 | def stop(self): 103 | if self._listen_poll: 104 | self._listen_poll.unregister(self._listen_s) 105 | self._listen_poll = None 106 | if self._listen_s: 107 | self._listen_s.close() 108 | self._listen_s = None 109 | 110 | for client in self._clients: 111 | client.connection.close() 112 | print("Stopped WebSocket server.") 113 | 114 | def start(self, port=80): 115 | if self._listen_s: 116 | self.stop() 117 | self._setup_conn(port) 118 | print("Started WebSocket server.") 119 | 120 | def process_all(self): 121 | self._check_new_connections(self._accept_conn) 122 | 123 | for client in self._clients: 124 | client.process() 125 | 126 | def remove_connection(self, conn): 127 | for client in self._clients: 128 | if client.connection is conn: 129 | self._clients.remove(client) 130 | return 131 | -------------------------------------------------------------------------------- /micropython/lib/websockets.py: -------------------------------------------------------------------------------- 1 | from websocket.ws_connection import ClientClosedError 2 | from websocket.ws_server import WebSocketServer, WebSocketClient 3 | 4 | 5 | class BasicClient(WebSocketClient): 6 | def __init__(self, conn): 7 | super().__init__(conn) 8 | 9 | def process(self): 10 | try: 11 | msg = self.connection.read() 12 | if not msg: 13 | return 14 | msg = msg.decode("utf-8") 15 | print(msg) 16 | self.connection.write(msg) 17 | except ClientClosedError: 18 | self.connection.close() 19 | 20 | 21 | class BasicServer(WebSocketServer): 22 | def __init__(self): 23 | super().__init__("test.html", 2) 24 | 25 | def _make_client(self, conn): 26 | return BasicClient(conn) 27 | 28 | def broadcast(self, msg): 29 | for client in self._clients: 30 | client.connection.write(msg) 31 | 32 | 33 | def test(): 34 | server = BasicServer() 35 | server.start() 36 | try: 37 | while True: 38 | server.process_all() 39 | except KeyboardInterrupt: 40 | pass 41 | server.stop() 42 | -------------------------------------------------------------------------------- /micropython/main.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from micropython import alloc_emergency_exception_buf 3 | 4 | # allocate exception buffer for ISRs 5 | alloc_emergency_exception_buf(100) 6 | 7 | # enable and configure garbage collection 8 | gc.enable() 9 | gc.collect() 10 | gc.threshold(gc.mem_free() // 4 + gc.mem_alloc()) 11 | -------------------------------------------------------------------------------- /micropython/mpy-esp32-networking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "\u001b[34mConnecting to --port=/dev/tty.usbserial-02EDKZTP --baud=115200 \u001b[0m\n", 13 | "MicroPython v1.14 on 2021-04-12; ESP32 module with ESP32\n", 14 | "Type \"help()\" for more information.\n", 15 | ">>>[reboot detected 0]repl is in normal command mode\n", 16 | "[\\r\\x03\\x03] b'\\r\\n>>> '\n", 17 | "[\\r\\x01] b'\\r\\n>>> \\r\\nraw REPL; CTRL-B to exit\\r\\n>' \u001b[34mReady.\n", 18 | "\u001b[0m" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "%serialconnect to --port=\"/dev/tty.usbserial-02EDKZTP\" --baud=115200" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 281, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "connecting to network...\n", 36 | ".network config: ('192.168.0.150', '255.255.255.0', '192.168.0.1', '192.168.0.1')\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "from lib.utils import connect_wifi\n", 42 | "\n", 43 | "connect_wifi()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 52, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from lib.websockets import BasicServer, BasicClient\n", 53 | "from lib.scheduling import WsDataScheduler\n", 54 | "\n", 55 | "server = BasicServer()\n", 56 | "data_scheduler = WsDataScheduler(5, server)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 53, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "network config: ('192.168.0.150', '255.255.255.0', '192.168.0.1', '192.168.0.1')\n", 69 | "WebSocket started on ws://192.168.0.150:80\n", 70 | "Started WebSocket server.\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "data_scheduler.start()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 49, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "[leftinbuffer] [\"Client connection from: ('192.168.0.107', 59991)\"]\n", 88 | "[leftinbuffer] ['Closing connection.']\n", 89 | "[leftinbuffer] [\"Client connection from: ('192.168.0.107', 60059)\"]\n", 90 | "[leftinbuffer] [\"Client connection from: ('192.168.0.107', 60062)\"]\n", 91 | "Closing connection.\n", 92 | "Stopped WebSocket server.\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "data_scheduler.stop()" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 276, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "# %fetchfile --print /websocket/ws_connection.py" 107 | ] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": "MicroPython - USB", 113 | "language": "micropython", 114 | "name": "micropython" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": "python", 118 | "file_extension": ".py", 119 | "mimetype": "text/python", 120 | "name": "micropython" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 4 125 | } 126 | -------------------------------------------------------------------------------- /micropython/mpy-esp32.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## MicroPython ESP32 Experimentation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Establishing connection to target board\n", 15 | "First, make sure you've got the right serial port. On unix-based systems, you can run `ls /dev/tty.*` to see your available serial devices. Replace as necessary below.\n", 16 | "\n", 17 | "This will allow Jupyter (your host computer) to run commands and send/receive information to/from your target board in real time using the MicroPython REPL." 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "\u001b[34mConnecting to --port=/dev/tty.usbserial-02U1W54L --baud=115200 \u001b[0m\n", 30 | "MicroPython d8a7bf8-dirty on 2022-02-09; ESP32 module with ESP32\n", 31 | "Type \"help()\" for more information.\n", 32 | ">>>[reboot detected 0]repl is in normal command mode\n", 33 | "[\\r\\x03\\x03] b'\\r\\n>>> '\n", 34 | "[\\r\\x01] b'\\r\\n>>> \\r\\nraw REPL; CTRL-B to exit\\r\\n>' \u001b[34mReady.\n", 35 | "\u001b[0m" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "%serialconnect to --port=\"/dev/tty.usbserial-02U1W54L\" --baud=115200\n", 41 | "# %serialconnect to --port=\"/dev/tty.usbserial-0001\" --baud=115200" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "Sent 246 lines (8210 bytes) to lib/runner.py.\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "%sendtofile lib/runner.py --source lib/runner.py" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 15, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "Sent 47 lines (1573 bytes) to lib/synthetic.py.\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "%sendtofile lib/synthetic.py --source lib/synthetic.py" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "Sent 238 lines (8870 bytes) to lib/decoding.py.\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "%sendtofile lib/decoding.py --source lib/decoding.py" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "## Using a Runner for experimentation and logging\n", 100 | "The a `Runner` is encapsulates the core functions in this EEG system, including peripheral setup, sampling, signal processing, logging and memory management. The `OnlineRunner` offers mostly the same functionality as the standard `Runner` class, except it allows for logging and other communication with a remote server - either on the Internet or on your local network.\n", 101 | "\n", 102 | "### Offline functionality\n", 103 | "The standard `Runner` is good for testing core functionality without the need for remote logging. See below for initialisation and execution." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 23, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "ADC initialised\n", 116 | "SPI initialised\n", 117 | "DigiPot set to 100 = gain of 10.62497708393011\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "from lib.runner import Runner\n", 123 | "\n", 124 | "Nc = 1\n", 125 | "Ns = 128\n", 126 | "Nt = 3\n", 127 | "stim_freqs = [7, 10, 12]\n", 128 | "\n", 129 | "# Here, we select the algorithm. Can be one of ['MsetCCA', 'GCCA', 'CCA']\n", 130 | "decoding_algo = 'MsetCCA'\n", 131 | "\n", 132 | "runner = Runner(decoding_algo, buffer_size=Ns) # initialise a base runner\n", 133 | "runner.setup() # setup peripherals and memory buffers" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "### Calibration\n", 141 | "If you are using an algorithm that leverages calibration data (MsetCCA, GCCA), you will need to record some calibration data to fit the decoder model. This is usually only done once off before inference starts. You may want to recalibrate at some semi-regular interval too though. \n", 142 | "\n", 143 | "At the moment, there is not an integrated process to record calibration data in the `Runner` class. You have to record calibration data and provide it to the runner which it will in turn use to fit its internal decoder model. In future, this will hopefully become more integrated and easy. For now, some random calibration data is generated below to illustrate the format which the runner/decoder expects. You need to provide iid calibration data trials for each stimulus frequency.\n", 144 | "\n", 145 | "Note that if you try to run calibration using an incompatible algorithm (such as standard CCA), a warning will be generated and the calibration sequence will be skipped." 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 34, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "from lib.synthetic import synth_X\n", 155 | "\n", 156 | "calibration_data = {f:synth_X(f, Nc, Ns, Nt=Nt) for f in stim_freqs}\n", 157 | "runner.calibrate(calibration_data)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "### Decoding\n", 165 | "When configured with a set of stimulus frequencies $\\mathcal{F}=\\{f_1, \\dots, f_k, \\dots, f_K\\}$, the `Runner`'s decoder model consists of $K$ independent sub-classifiers $\\Phi_k$ that each leverage the decoding algorithm selected. These independent classifiers must be calibrated independently. When the `Runner` is presented a new test observation, each sub-classifier $\\Phi_k$ produces an output correlation estimate corresponding to $f_k$. Ultimately, the runner outputs a dictionary of frequency-correlation pairs of the form\n", 166 | "```python\n", 167 | "{f_1: 0.12, f_2: 0.03, f_3: 0.85}\n", 168 | "```\n", 169 | "The decoded output frequency is the one corresponding to the largest correlation in this output dictionary. In this example, it would be $f_3$." 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 36, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "{12: 0.005366318394671077, 10: 0.0157398273859412, 7: 0.9957282993427281}\n" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "test_freq = 7 # 7 Hz test signal\n", 187 | "test_data = synth_X(test_freq, Nc, Ns, Nt=1)\n", 188 | "\n", 189 | "print(runner.decoder.classify(test_data))" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "### Asynchronous operation\n", 197 | "Once the `Runner` has been configured and calibrated (if applicable), its internal `run()` loop can be started in which it will asynchronously sample and decode EEG data at preconfigured frequencies. Timing is handled using hardware timers on the ESP32 and interrupts are used to run asynchronous ISRs that handle sampling, preprocessing, filtering and decoding.\n", 198 | "\n", 199 | "Note that once the async run loop has begun, you can still run commands or view the `Runner`'s attributes although there may be a noticeable delay since ISRs will typically get higher execution priority and there are quite a few interrupt loops running." 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 25, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# start sampling and recording data (logging not setup in this case)\n", 209 | "runner.run()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 26, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "True\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "# see if runner has indeed started smapling\n", 227 | "print(runner.is_sampling)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 27, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "# display the contents of the output buffer - this will be updated internally by the runner\n", 245 | "# at a rate determined by the sampling frequency and sample buffer size (typically every 1s)\n", 246 | "print(runner.output_buffer)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 33, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "{12: 0.09387, 10: 0.11474, 7: 0.05861999999999999}\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "# decode the contents of the output buffer. There will be a delay here if the runner \n", 264 | "# is currently running (i.e. `is_sampling=True`).\n", 265 | "print(runner.decode())" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 30, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "# stop runner\n", 275 | "runner.stop()" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "#### Simple decoding loop\n", 283 | "In order to test online decoding, here is a basic synchronous loop-based option. Interrupt the cell to stop the infinite loop." 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 2, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "ADC initialised\n", 296 | "SPI initialised\n", 297 | "DigiPot set to 100 = gain of 10.62497708393011\n", 298 | "{}\n", 299 | ".{12: 0.15038, 10: 0.02075, 7: 0.11349}\n", 300 | "{12: 0.15038, 10: 0.02075, 7: 0.11349}\n", 301 | ".{12: 0.04253, 10: 0.02158, 7: 0.00613}\n", 302 | "{12: 0.04253, 10: 0.02158, 7: 0.00613}\n", 303 | "{12: 0.05438, 10: 0.00615, 7: 0.08}\n", 304 | ".{12: 0.05438, 10: 0.00615, 7: 0.08}\n", 305 | "{12: 0.0278, 10: 0.00456, 7: 0.02368}\n", 306 | "\u001b[34m\n", 307 | "\n", 308 | "*** Sending Ctrl-C\n", 309 | "\n", 310 | "\u001b[0mreceived SIGINT - stopping\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "import utime as time\n", 316 | "from lib.runner import Runner\n", 317 | "\n", 318 | "Nc = 1\n", 319 | "Ns = 128\n", 320 | "Nt = 3\n", 321 | "stim_freqs = [7, 10, 12]\n", 322 | "\n", 323 | "# Here, we select the algorithm. Can be one of ['MsetCCA', 'GCCA', 'CCA']\n", 324 | "decoding_algo = 'MsetCCA'\n", 325 | "\n", 326 | "decode_period_s = 2 # read decoded output every x seconds\n", 327 | "\n", 328 | "runner = Runner(decoding_algo, buffer_size=Ns) # initialise a base runner\n", 329 | "runner.setup()\n", 330 | "\n", 331 | "if decoding_algo in ['MsetCCA', 'GCCA']:\n", 332 | " from lib.synthetic import synth_X\n", 333 | "\n", 334 | " calibration_data = {f:synth_X(f, Nc, Ns, Nt=Nt) for f in stim_freqs}\n", 335 | " runner.calibrate(calibration_data)\n", 336 | "\n", 337 | "runner.run() # start async run loop\n", 338 | "\n", 339 | "try:\n", 340 | " while True:\n", 341 | " time.sleep(decode_period_s)\n", 342 | " print(runner.decoded_output)\n", 343 | "except KeyboardInterrupt:\n", 344 | " runner.stop()\n", 345 | " print('received SIGINT - stopping')" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": {}, 351 | "source": [ 352 | "### Testing your WiFi connection\n", 353 | "In order to connect to a local WiFi network, you'll need to supply your network SSID and password in a `.env` file on the board. Doing this is easy: \n", 354 | "1. On your computer, create a `.env` file using `touch .env`. Update the `.env` file with the required fields:\n", 355 | " \n", 356 | " ```bash\n", 357 | " #.env \n", 358 | " WIFI_SSID=\n", 359 | " WIFI_PASSWORD=\n", 360 | " \n", 361 | " ```\n", 362 | " \n", 363 | "2. Send this file to your target device using the following command:\n", 364 | " ```ipython\n", 365 | "%sendtofile --source lib/.env lib/.env --binary\n", 366 | "```\n", 367 | "\n", 368 | "You may need to update the local (source) path to your `.env` file depending on where you created/stored it." 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "from lib.utils import connect_wifi, load_env_vars\n", 378 | "\n", 379 | "env_vars = load_env_vars(\"lib/.env\")\n", 380 | "# connect WiFI\n", 381 | "ssid = env_vars.get(\"WIFI_SSID\")\n", 382 | "password = env_vars.get(\"WIFI_PASSWORD\")\n", 383 | "connect_wifi(ssid, password)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "#### Online Runner\n", 391 | "Now that you've established network connectivitiy, you can test out an `OnlineRunner`. In order to test web logging to a remote server, we can use a basic HTTP logger. However, this obviously needs an API/server willing to accept our requests. There is a basic logging API using `Flask` in `/eeg_lib/logging_server.py`. You can run it using `python logging_server.py` which will spin up a development server on the predefined port (5000 or 5001). Then, just configure your `OnlineRunner` with the appropriate logger params and you're set." 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 19, 397 | "metadata": {}, 398 | "outputs": [ 399 | { 400 | "name": "stdout", 401 | "output_type": "stream", 402 | "text": [ 403 | "ADC initialised\n", 404 | "SPI initialised\n", 405 | "DigiPot set to 100 = gain of 10.62498\n", 406 | "network config: ('192.168.0.28', '255.255.255.0', '192.168.0.1', '192.168.0.1')\n" 407 | ] 408 | } 409 | ], 410 | "source": [ 411 | "from lib.runner import OnlineRunner\n", 412 | "from lib.logging import logger_types\n", 413 | "\n", 414 | "api_host = \"http://192.168.0.2:5001/\" # make sure the port corresponds to your logging server configuration\n", 415 | "log_params = dict(server=api_host, log_period=4, logger_type=logger_types.HTTP, send_raw=True, session_id='test_session_1')\n", 416 | "\n", 417 | "runner = OnlineRunner()\n", 418 | "runner.setup(**log_params)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 20, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "# start the runner - you should see requests being made to your local server\n", 428 | "runner.run()" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 21, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "runner.stop()" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "## Experimentation" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 155, 450 | "metadata": {}, 451 | "outputs": [ 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "repl is in normal command mode\n", 457 | "[\\r\\x03\\x03] b'\\r\\nMicroPython d8a7bf8-dirty on 2022-02-09; ESP32 module with ESP32\\r\\nType \"help()\" for more information.\\r\\n>>> \\r\\n>>> \\r\\nMPY: soft reboot\\r\\nMicroPython d8a7bf8-dirty on 2022-02-09; ESP32 module with ESP32\\r\\nType \"help()\" for more information.\\r\\n>>> \\r\\n>>> \\r\\n>>> '\n", 458 | "[\\r\\x01] b'\\r\\n>>> \\r\\nraw REPL; CTRL-B to exit\\r\\n>'" 459 | ] 460 | } 461 | ], 462 | "source": [ 463 | "%rebootdevice" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 12, 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "%capture [--quiet] [--QUIET] outputfilename\n", 476 | " records output to a file\n", 477 | "\n", 478 | "%comment\n", 479 | " print this into output\n", 480 | "\n", 481 | "%disconnect [--raw]\n", 482 | " disconnects from web/serial connection\n", 483 | "\n", 484 | "%esptool [--port PORT] {erase,esp32,esp8266} [binfile]\n", 485 | " commands for flashing your esp-device\n", 486 | "\n", 487 | "%fetchfile [--binary] [--print] [--load] [--quiet] [--QUIET]\n", 488 | " sourcefilename [destinationfilename]\n", 489 | " fetch and save a file from the device\n", 490 | "\n", 491 | "%ls [--recurse] [dirname]\n", 492 | " list files on the device\n", 493 | "\n", 494 | "%lsmagic\n", 495 | " list magic commands\n", 496 | "\n", 497 | "%mpy-cross [--set-exe SET_EXE] [pyfile]\n", 498 | " cross-compile a .py file to a .mpy file\n", 499 | "\n", 500 | "%readbytes [--binary]\n", 501 | " does serial.read_all()\n", 502 | "\n", 503 | "%rebootdevice\n", 504 | " reboots device\n", 505 | "\n", 506 | "%sendtofile [--append] [--mkdir] [--binary] [--execute] [--source [SOURCE]] [--quiet]\n", 507 | " [--QUIET]\n", 508 | " [destinationfilename]\n", 509 | " send cell contents or file/direcectory to the device\n", 510 | "\n", 511 | "%serialconnect [--raw] [--port PORT] [--baud BAUD] [--verbose]\n", 512 | " connects to a device over USB wire\n", 513 | "\n", 514 | "%socketconnect [--raw] ipnumber portnumber\n", 515 | " connects to a socket of a device over wifi\n", 516 | "\n", 517 | "%suppressendcode\n", 518 | " doesn't send x04 or wait to read after sending the contents of the cell\n", 519 | " (assists for debugging using %writebytes and %readbytes)\n", 520 | "\n", 521 | "%websocketconnect [--raw] [--password PASSWORD] [--verbose] [websocketurl]\n", 522 | " connects to the webREPL websocket of an ESP8266 over wifi\n", 523 | " websocketurl defaults to ws://192.168.4.1:8266 but be sure to be connected\n", 524 | "\n", 525 | "%writebytes [--binary] [--verbose] stringtosend\n", 526 | " does serial.write() of the python quoted string given\n", 527 | "\n", 528 | "%%writefile [--append] [--execute] destinationfilename\n", 529 | " write contents of cell to a file\n", 530 | "\n" 531 | ] 532 | } 533 | ], 534 | "source": [ 535 | "%lsmagic" 536 | ] 537 | } 538 | ], 539 | "metadata": { 540 | "kernelspec": { 541 | "display_name": "MicroPython - USB", 542 | "language": "micropython", 543 | "name": "micropython" 544 | }, 545 | "language_info": { 546 | "codemirror_mode": "python", 547 | "file_extension": ".py", 548 | "mimetype": "text/python", 549 | "name": "micropython" 550 | } 551 | }, 552 | "nbformat": 4, 553 | "nbformat_minor": 4 554 | } 555 | -------------------------------------------------------------------------------- /micropython/mpy-modules/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesTev/EEG-decoding/c4054dd9d1eac857aedd487a34f177c97d95c0af/micropython/mpy-modules/.gitkeep -------------------------------------------------------------------------------- /misc/esp32-fft/esp32-fft.ino: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from Robin Scheibler's ESP32 FFT lib. 3 | * 4 | * James Teversham 5 | * Imperial College London 6 | * 2021 7 | */ 8 | #include 9 | #include 10 | #include "freertos/FreeRTOS.h" 11 | #include "freertos/task.h" 12 | #include "sdkconfig.h" 13 | 14 | #include "soc/timer_group_struct.h" 15 | #include "driver/periph_ctrl.h" 16 | #include "driver/timer.h" 17 | 18 | extern "C"{ 19 | #include "fft.h" 20 | }; 21 | 22 | #define REP 100 23 | #define MIN_LOG_N 6 24 | #define MAX_LOG_N 12 25 | 26 | #define GPIO_OUTPUT 27 //27 27 | 28 | double start, end; 29 | const int ledPin = LED_BUILTIN;// the number of the LED pin 30 | 31 | timer_config_t timer_config = { 32 | .alarm_en = false, 33 | .counter_en = true, 34 | .intr_type = TIMER_INTR_LEVEL, 35 | .counter_dir = TIMER_COUNT_UP, 36 | .auto_reload = TIMER_AUTORELOAD_DIS, 37 | .divider = 80 /* 1 us per tick */ 38 | }; 39 | 40 | gpio_config_t gpio_conf = { 41 | //bit mask of the pins that you want to set,e.g.GPIO18/19 42 | .pin_bit_mask = (1 << gpio_num_t(GPIO_OUTPUT)), 43 | //set as output mode 44 | .mode = GPIO_MODE_OUTPUT, 45 | //disable pull-up mode 46 | .pull_up_en = GPIO_PULLUP_DISABLE, 47 | //disable pull-down mode 48 | .pull_down_en = GPIO_PULLDOWN_DISABLE, 49 | // disable interrupt 50 | .intr_type = GPIO_INTR_DISABLE 51 | }; 52 | 53 | void clock_init() 54 | { 55 | timer_init(TIMER_GROUP_0, TIMER_0, &timer_config); 56 | timer_set_counter_value(TIMER_GROUP_0, TIMER_0, 0); 57 | timer_start(TIMER_GROUP_0, TIMER_0); 58 | } 59 | 60 | void fft4_test_task() 61 | { 62 | int k; 63 | float input[8] = { 7, 8, 4, 4, 1, 1, 6, 8 }; 64 | float output[8]; 65 | float gt[8] = { 18., 21., 2., 9., -2., -3., 10., 5. }; 66 | 67 | fft4(input, 2, output, 2); 68 | 69 | printf("-----------\n"); 70 | for (k = 0 ; k < 8 ; k+=2) 71 | printf("%.2f%+.2fj ", output[k], output[k+1]); 72 | printf("\n"); 73 | for (k = 0 ; k < 8 ; k+=2) 74 | printf("%.2f%+.2fj ", gt[k], gt[k+1]); 75 | printf("\n"); 76 | printf("-----------\n"); 77 | } 78 | 79 | void fft8_test_task() 80 | { 81 | int k; 82 | float input[16] = { 7, 8, 4, 4, 1, 1, 6, 8, 1, 1, 9, 6, 0, 8, 7, 4 }; 83 | float output[16]; 84 | float gt[16] = { 35., 40., -2.41421356, 6., 5.00000000, 0., 17.24264069, 16.48528137, -17., -4., 0.41421356, 6., 9.00000000, 0., 8.75735931, -0.48528137 }; 85 | 86 | fft8(input, 2, output, 2); 87 | 88 | printf("-----------\n"); 89 | for (k = 0 ; k < 16 ; k+=2) 90 | printf("%.2f+%.2fj ", output[k], output[k+1]); 91 | printf("\n"); 92 | for (k = 0 ; k < 16 ; k+=2) 93 | printf("%.2f+%.2fj ", gt[k], gt[k+1]); 94 | printf("\n"); 95 | printf("-----------\n"); 96 | } 97 | 98 | void fft_test_task() 99 | { 100 | int k, n; 101 | 102 | for (n = MIN_LOG_N ; n <= MAX_LOG_N ; n++) 103 | { 104 | int NFFT = 1 << n; 105 | 106 | // Create fft plan and let it allocate arrays 107 | fft_config_t *fft_analysis = fft_init(NFFT, FFT_COMPLEX, FFT_FORWARD, NULL, NULL); 108 | fft_config_t *fft_synthesis = fft_init(NFFT, FFT_COMPLEX, FFT_BACKWARD, fft_analysis->output, NULL); 109 | 110 | // Fill array with some dummy data 111 | for (k = 0 ; k < fft_analysis->size ; k++) 112 | { 113 | fft_analysis->input[2*k] = (float)k / (float)fft_analysis->size; 114 | fft_analysis->input[2*k+1] = (float)(k-1) / (float)fft_analysis->size; 115 | } 116 | 117 | // Test accuracy 118 | fft_execute(fft_analysis); 119 | fft_execute(fft_synthesis); 120 | 121 | int n_errors = 0; 122 | for (k = 0 ; k < 2 * fft_analysis->size ; k++) 123 | if (abs(fft_analysis->input[k] - fft_synthesis->output[k]) > 1e-5) 124 | { 125 | printf("bin=%d input=%.4f output=%.4f\n err=%f", 126 | k, fft_analysis->input[k], fft_synthesis->output[k], 127 | fabsf(fft_analysis->input[k] - fft_synthesis->output[k])); 128 | n_errors++; 129 | } 130 | if (n_errors == 0) 131 | printf("Transform seems to work!\n"); 132 | 133 | // Now measure execution time 134 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &start); 135 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 1); 136 | for (k = 0 ; k < REP ; k++) 137 | fft_execute(fft_analysis); 138 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 0); 139 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &end); 140 | printf(" FFT size=%d runtime=%f ms\n", NFFT, 1000 * (end - start) / REP); 141 | 142 | vTaskDelay(10 / portTICK_RATE_MS); 143 | 144 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &start); 145 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 1); 146 | for (k = 0 ; k < REP ; k++) 147 | fft_execute(fft_synthesis); 148 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 0); 149 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &end); 150 | printf("iFFT size=%d runtime=%f ms\n", NFFT, 1000 * (end - start) / REP); 151 | 152 | fft_destroy(fft_analysis); 153 | fft_destroy(fft_synthesis); 154 | } 155 | } 156 | 157 | void rfft_test_task() 158 | { 159 | int k, n; 160 | 161 | for (n = MIN_LOG_N ; n <= MAX_LOG_N ; n++) 162 | { 163 | int NFFT = 1 << n; 164 | 165 | // Create fft plan and let it allocate arrays 166 | fft_config_t *fft_analysis = fft_init(NFFT, FFT_REAL, FFT_FORWARD, NULL, NULL); 167 | fft_config_t *fft_synthesis = fft_init(NFFT, FFT_REAL, FFT_BACKWARD, fft_analysis->output, NULL); 168 | 169 | // Fill array with some dummy data 170 | for (k = 0 ; k < fft_analysis->size ; k++) 171 | fft_analysis->input[k] = (float)k / (float)fft_analysis->size; 172 | 173 | // Test accuracy 174 | fft_execute(fft_analysis); 175 | fft_execute(fft_synthesis); 176 | 177 | int n_errors = 0; 178 | for (k = 0 ; k < fft_analysis->size ; k++) 179 | if (abs(fft_analysis->input[k] - fft_synthesis->output[k]) > 1e-5) 180 | { 181 | printf("bin=%d input=%.4f output=%.4f\n err=%f", 182 | k, fft_analysis->input[k], fft_synthesis->output[k], 183 | fabsf(fft_analysis->input[k] - fft_synthesis->output[k])); 184 | n_errors++; 185 | } 186 | if (n_errors == 0) 187 | printf("Transform seems to work!\n"); 188 | 189 | // Now measure execution time 190 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &start); 191 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 1); 192 | for (k = 0 ; k < REP ; k++) 193 | fft_execute(fft_analysis); 194 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 0); 195 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &end); 196 | printf(" Real FFT size=%d runtime=%f ms\n", NFFT, 1000 * (end - start) / REP); 197 | 198 | vTaskDelay(10 / portTICK_RATE_MS); 199 | 200 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &start); 201 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 1); 202 | for (k = 0 ; k < REP ; k++) 203 | fft_execute(fft_synthesis); 204 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 0); 205 | timer_get_counter_time_sec(TIMER_GROUP_0, TIMER_0, &end); 206 | printf("Real iFFT size=%d runtime=%f ms\n", NFFT, 1000 * (end - start) / REP); 207 | 208 | fft_destroy(fft_analysis); 209 | fft_destroy(fft_synthesis); 210 | } 211 | } 212 | void setup() { 213 | // put your setup code here, to run once: 214 | gpio_config(&gpio_conf); 215 | gpio_set_level(gpio_num_t(GPIO_OUTPUT), 0); 216 | pinMode(ledPin, OUTPUT); 217 | clock_init(); 218 | } 219 | 220 | void loop() { 221 | // put your main code here, to run repeatedly: 222 | fft_test_task(); 223 | rfft_test_task(); 224 | digitalWrite(ledPin, !digitalRead(ledPin)); 225 | //fft8_test_task(); 226 | //fft4_test_task(); 227 | vTaskDelay(1000 / portTICK_RATE_MS); 228 | } 229 | -------------------------------------------------------------------------------- /misc/esp32-fft/fft.c: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | ESP32 FFT 4 | ========= 5 | 6 | This provides a vanilla radix-2 FFT implementation and a test example. 7 | 8 | Author 9 | ------ 10 | 11 | This code was written by [Robin Scheibler](http://www.robinscheibler.org) during rainy days in October 2017. 12 | 13 | License 14 | ------- 15 | 16 | Copyright (c) 2017 Robin Scheibler 17 | 18 | Permission is hereby granted, free of charge, to any person obtaining a copy 19 | of this software and associated documentation files (the "Software"), to deal 20 | in the Software without restriction, including without limitation the rights 21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 22 | copies of the Software, and to permit persons to whom the Software is 23 | furnished to do so, subject to the following conditions: 24 | 25 | The above copyright notice and this permission notice shall be included in all 26 | copies or substantial portions of the Software. 27 | 28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 34 | SOFTWARE. 35 | 36 | */ 37 | #include 38 | #include 39 | #include 40 | #include 41 | 42 | #include "fft.h" 43 | 44 | #define TWO_PI 6.28318530 45 | #define USE_SPLIT_RADIX 1 46 | #define LARGE_BASE_CASE 1 47 | 48 | fft_config_t *fft_init(int size, fft_type_t type, fft_direction_t direction, float *input, float *output) 49 | { 50 | /* 51 | * Prepare an FFT of correct size and types. 52 | * 53 | * If no input or output buffers are provided, they will be allocated. 54 | */ 55 | int k,m; 56 | 57 | fft_config_t *config = (fft_config_t *)malloc(sizeof(fft_config_t)); 58 | 59 | // Check if the size is a power of two 60 | if ((size & (size-1)) != 0) // tests if size is a power of two 61 | return NULL; 62 | 63 | // start configuration 64 | config->flags = 0; 65 | config->type = type; 66 | config->direction = direction; 67 | config->size = size; 68 | 69 | // Allocate and precompute twiddle factors 70 | config->twiddle_factors = (float *)malloc(2 * config->size * sizeof(float)); 71 | 72 | float two_pi_by_n = TWO_PI / config->size; 73 | 74 | for (k = 0, m = 0 ; k < config->size ; k++, m+=2) 75 | { 76 | config->twiddle_factors[m] = cosf(two_pi_by_n * k); // real 77 | config->twiddle_factors[m+1] = sinf(two_pi_by_n * k); // imag 78 | } 79 | 80 | // Allocate input buffer 81 | if (input != NULL) 82 | config->input = input; 83 | else 84 | { 85 | if (config->type == FFT_REAL) 86 | config->input = (float *)malloc(config->size * sizeof(float)); 87 | else if (config->type == FFT_COMPLEX) 88 | config->input = (float *)malloc(2 * config->size * sizeof(float)); 89 | 90 | config->flags |= FFT_OWN_INPUT_MEM; 91 | } 92 | 93 | if (config->input == NULL) 94 | return NULL; 95 | 96 | // Allocate output buffer 97 | if (output != NULL) 98 | config->output = output; 99 | else 100 | { 101 | if (config->type == FFT_REAL) 102 | config->output = (float *)malloc(config->size * sizeof(float)); 103 | else if (config->type == FFT_COMPLEX) 104 | config->output = (float *)malloc(2 * config->size * sizeof(float)); 105 | 106 | config->flags |= FFT_OWN_OUTPUT_MEM; 107 | } 108 | 109 | if (config->output == NULL) 110 | return NULL; 111 | 112 | return config; 113 | } 114 | 115 | void fft_destroy(fft_config_t *config) 116 | { 117 | if (config->flags & FFT_OWN_INPUT_MEM) 118 | free(config->input); 119 | 120 | if (config->flags & FFT_OWN_OUTPUT_MEM) 121 | free(config->output); 122 | 123 | free(config->twiddle_factors); 124 | free(config); 125 | } 126 | 127 | void fft_execute(fft_config_t *config) 128 | { 129 | if (config->type == FFT_REAL && config->direction == FFT_FORWARD) 130 | rfft(config->input, config->output, config->twiddle_factors, config->size); 131 | else if (config->type == FFT_REAL && config->direction == FFT_BACKWARD) 132 | irfft(config->input, config->output, config->twiddle_factors, config->size); 133 | else if (config->type == FFT_COMPLEX && config->direction == FFT_FORWARD) 134 | fft(config->input, config->output, config->twiddle_factors, config->size); 135 | else if (config->type == FFT_COMPLEX && config->direction == FFT_BACKWARD) 136 | ifft(config->input, config->output, config->twiddle_factors, config->size); 137 | } 138 | 139 | void fft(float *input, float *output, float *twiddle_factors, int n) 140 | { 141 | /* 142 | * Forward fast Fourier transform 143 | * DIT, radix-2, out-of-place implementation 144 | * 145 | * Parameters 146 | * ---------- 147 | * input (float *) 148 | * The input array containing the complex samples with 149 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 150 | * output (float *) 151 | * The output array containing the complex samples with 152 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 153 | * n (int) 154 | * The FFT size, should be a power of 2 155 | */ 156 | 157 | #if USE_SPLIT_RADIX 158 | split_radix_fft(input, output, n, 2, twiddle_factors, 2); 159 | #else 160 | fft_primitive(input, output, n, 2, twiddle_factors, 2); 161 | #endif 162 | } 163 | 164 | void ifft(float *input, float *output, float *twiddle_factors, int n) 165 | { 166 | /* 167 | * Inverse fast Fourier transform 168 | * DIT, radix-2, out-of-place implementation 169 | * 170 | * Parameters 171 | * ---------- 172 | * input (float *) 173 | * The input array containing the complex samples with 174 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 175 | * output (float *) 176 | * The output array containing the complex samples with 177 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 178 | * n (int) 179 | * The FFT size, should be a power of 2 180 | */ 181 | ifft_primitive(input, output, n, 2, twiddle_factors, 2); 182 | } 183 | 184 | void rfft(float *x, float *y, float *twiddle_factors, int n) 185 | { 186 | 187 | // This code uses the two-for-the-price-of-one strategy 188 | #if USE_SPLIT_RADIX 189 | split_radix_fft(x, y, n / 2, 2, twiddle_factors, 4); 190 | #else 191 | fft_primitive(x, y, n / 2, 2, twiddle_factors, 4); 192 | #endif 193 | 194 | // Now apply post processing to recover positive 195 | // frequencies of the real FFT 196 | float t = y[0]; 197 | y[0] = t + y[1]; // DC coefficient 198 | y[1] = t - y[1]; // Center coefficient 199 | 200 | // Apply post processing to quarter element 201 | // this boils down to taking complex conjugate 202 | y[n/2+1] = -y[n/2+1]; 203 | 204 | // Now process all the other frequencies 205 | int k; 206 | for (k = 2 ; k < n / 2 ; k += 2) 207 | { 208 | float xer, xei, xor, xoi, c, s, tr, ti; 209 | 210 | c = twiddle_factors[k]; 211 | s = twiddle_factors[k+1]; 212 | 213 | // even half coefficient 214 | xer = 0.5 * (y[k] + y[n-k]); 215 | xei = 0.5 * (y[k+1] - y[n-k+1]); 216 | 217 | // odd half coefficient 218 | xor = 0.5 * (y[k+1] + y[n-k+1]); 219 | xoi = - 0.5 * (y[k] - y[n-k]); 220 | 221 | tr = c * xor + s * xoi; 222 | ti = -s * xor + c * xoi; 223 | 224 | y[k] = xer + tr; 225 | y[k+1] = xei + ti; 226 | 227 | y[n-k] = xer - tr; 228 | y[n-k+1] = -(xei - ti); 229 | } 230 | } 231 | 232 | void irfft(float *x, float *y, float *twiddle_factors, int n) 233 | { 234 | /* 235 | * Destroys content of input vector 236 | */ 237 | int k; 238 | 239 | // Here we need to apply a pre-processing first 240 | float t = x[0]; 241 | x[0] = 0.5 * (t + x[1]); 242 | x[1] = 0.5 * (t - x[1]); 243 | 244 | x[n/2+1] = -x[n/2+1]; 245 | 246 | for (k = 2 ; k < n / 2 ; k += 2) 247 | { 248 | float xer, xei, xor, xoi, c, s, tr, ti; 249 | 250 | c = twiddle_factors[k]; 251 | s = twiddle_factors[k+1]; 252 | 253 | xer = 0.5 * (x[k] + x[n-k]); 254 | tr = 0.5 * (x[k] - x[n-k]); 255 | 256 | xei = 0.5 * (x[k+1] - x[n-k+1]); 257 | ti = 0.5 * (x[k+1] + x[n-k+1]); 258 | 259 | xor = c * tr - s * ti; 260 | xoi = s * tr + c * ti; 261 | 262 | x[k] = xer - xoi; 263 | x[k+1] = xor + xei; 264 | 265 | x[n-k] = xer + xoi; 266 | x[n-k+1] = xor - xei; 267 | } 268 | 269 | ifft_primitive(x, y, n / 2, 2, twiddle_factors, 4); 270 | } 271 | 272 | void fft_primitive(float *x, float *y, int n, int stride, float *twiddle_factors, int tw_stride) 273 | { 274 | /* 275 | * This code will compute the FFT of the input vector x 276 | * 277 | * The input data is assumed to be real/imag interleaved 278 | * 279 | * The size n should be a power of two 280 | * 281 | * y is an output buffer of size 2n to accomodate for complex numbers 282 | * 283 | * Forward fast Fourier transform 284 | * DIT, radix-2, out-of-place implementation 285 | * 286 | * For a complex FFT, call first stage as: 287 | * fft(x, y, n, 2, 2); 288 | * 289 | * Parameters 290 | * ---------- 291 | * x (float *) 292 | * The input array containing the complex samples with 293 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 294 | * y (float *) 295 | * The output array containing the complex samples with 296 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 297 | * n (int) 298 | * The FFT size, should be a power of 2 299 | * stride (int) 300 | * The number of elements to skip between two successive samples 301 | * tw_stride (int) 302 | * The number of elements to skip between two successive twiddle factors 303 | */ 304 | int k; 305 | float t; 306 | 307 | #if LARGE_BASE_CASE 308 | // End condition, stop at n=8 to avoid one trivial recursion 309 | if (n == 8) 310 | { 311 | fft8(x, stride, y, 2); 312 | return; 313 | } 314 | #else 315 | // End condition, stop at n=2 to avoid one trivial recursion 316 | if (n == 2) 317 | { 318 | y[0] = x[0] + x[stride]; 319 | y[1] = x[1] + x[stride + 1]; 320 | y[2] = x[0] - x[stride]; 321 | y[3] = x[1] - x[stride + 1]; 322 | return; 323 | } 324 | #endif 325 | 326 | // Recursion -- Decimation In Time algorithm 327 | fft_primitive(x, y, n / 2, 2 * stride, twiddle_factors, 2 * tw_stride); // even half 328 | fft_primitive(x + stride, y+n, n / 2, 2 * stride, twiddle_factors, 2 * tw_stride); // odd half 329 | 330 | // Stitch back together 331 | 332 | // We can a few multiplications in the first step 333 | t = y[0]; 334 | y[0] = t + y[n]; 335 | y[n] = t - y[n]; 336 | 337 | t = y[1]; 338 | y[1] = t + y[n+1]; 339 | y[n+1] = t - y[n+1]; 340 | 341 | for (k = 1 ; k < n / 2 ; k++) 342 | { 343 | float x1r, x1i, x2r, x2i, c, s; 344 | c = twiddle_factors[k * tw_stride]; 345 | s = twiddle_factors[k * tw_stride + 1]; 346 | 347 | x1r = y[2 * k]; 348 | x1i = y[2 * k + 1]; 349 | x2r = c * y[n + 2 * k] + s * y[n + 2 * k + 1]; 350 | x2i = -s * y[n + 2 * k] + c * y[n + 2 * k + 1]; 351 | 352 | y[2 * k] = x1r + x2r; 353 | y[2 * k + 1] = x1i + x2i; 354 | 355 | y[n + 2 * k] = x1r - x2r; 356 | y[n + 2 * k + 1] = x1i - x2i; 357 | } 358 | 359 | } 360 | 361 | void split_radix_fft(float *x, float *y, int n, int stride, float *twiddle_factors, int tw_stride) 362 | { 363 | /* 364 | * This code will compute the FFT of the input vector x 365 | * 366 | * The input data is assumed to be real/imag interleaved 367 | * 368 | * The size n should be a power of two 369 | * 370 | * y is an output buffer of size 2n to accomodate for complex numbers 371 | * 372 | * Forward fast Fourier transform 373 | * Split-Radix 374 | * DIT, radix-2, out-of-place implementation 375 | * 376 | * For a complex FFT, call first stage as: 377 | * fft(x, y, n, 2, 2); 378 | * 379 | * Parameters 380 | * ---------- 381 | * x (float *) 382 | * The input array containing the complex samples with 383 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 384 | * y (float *) 385 | * The output array containing the complex samples with 386 | * real/imaginary parts interleaved [Re(x0), Im(x0), ..., Re(x_n-1), Im(x_n-1)] 387 | * n (int) 388 | * The FFT size, should be a power of 2 389 | * stride (int) 390 | * The number of elements to skip between two successive samples 391 | * twiddle_factors (float *) 392 | * The array of twiddle factors 393 | * tw_stride (int) 394 | * The number of elements to skip between two successive twiddle factors 395 | */ 396 | int k; 397 | 398 | #if LARGE_BASE_CASE 399 | // End condition, stop at n=2 to avoid one trivial recursion 400 | if (n == 8) 401 | { 402 | fft8(x, stride, y, 2); 403 | return; 404 | } 405 | else if (n == 4) 406 | { 407 | fft4(x, stride, y, 2); 408 | return; 409 | } 410 | #else 411 | // End condition, stop at n=2 to avoid one trivial recursion 412 | if (n == 2) 413 | { 414 | y[0] = x[0] + x[stride]; 415 | y[1] = x[1] + x[stride + 1]; 416 | y[2] = x[0] - x[stride]; 417 | y[3] = x[1] - x[stride + 1]; 418 | return; 419 | } 420 | else if (n == 1) 421 | { 422 | y[0] = x[0]; 423 | y[1] = x[1]; 424 | return; 425 | } 426 | #endif 427 | 428 | // Recursion -- Decimation In Time algorithm 429 | split_radix_fft(x, y, n / 2, 2 * stride, twiddle_factors, 2 * tw_stride); 430 | split_radix_fft(x + stride, y + n, n / 4, 4 * stride, twiddle_factors, 4 * tw_stride); 431 | split_radix_fft(x + 3 * stride, y + n + n / 2, n / 4, 4 * stride, twiddle_factors, 4 * tw_stride); 432 | 433 | // Stitch together the output 434 | float u1r, u1i, u2r, u2i, x1r, x1i, x2r, x2i; 435 | float t; 436 | 437 | // We can save a few multiplications in the first step 438 | u1r = y[0]; 439 | u1i = y[1]; 440 | u2r = y[n / 2]; 441 | u2i = y[n / 2 + 1]; 442 | 443 | x1r = y[n]; 444 | x1i = y[n + 1]; 445 | x2r = y[n / 2 + n]; 446 | x2i = y[n / 2 + n + 1]; 447 | 448 | t = x1r + x2r; 449 | y[0] = u1r + t; 450 | y[n] = u1r - t; 451 | 452 | t = x1i + x2i; 453 | y[1] = u1i + t; 454 | y[n + 1] = u1i - t; 455 | 456 | t = x2i - x1i; 457 | y[n / 2] = u2r - t; 458 | y[n + n / 2] = u2r + t; 459 | 460 | t = x1r - x2r; 461 | y[n / 2 + 1] = u2i - t; 462 | y[n + n / 2 + 1] = u2i + t; 463 | 464 | for (k = 1 ; k < n / 4 ; k++) 465 | { 466 | float u1r, u1i, u2r, u2i, x1r, x1i, x2r, x2i, c1, s1, c2, s2; 467 | c1 = twiddle_factors[k * tw_stride]; 468 | s1 = twiddle_factors[k * tw_stride + 1]; 469 | c2 = twiddle_factors[3 * k * tw_stride]; 470 | s2 = twiddle_factors[3 * k * tw_stride + 1]; 471 | 472 | u1r = y[2 * k]; 473 | u1i = y[2 * k + 1]; 474 | u2r = y[2 * k + n / 2]; 475 | u2i = y[2 * k + n / 2 + 1]; 476 | 477 | x1r = c1 * y[n + 2 * k] + s1 * y[n + 2 * k + 1]; 478 | x1i = -s1 * y[n + 2 * k] + c1 * y[n + 2 * k + 1]; 479 | x2r = c2 * y[n / 2 + n + 2 * k] + s2 * y[n / 2 + n + 2 * k + 1]; 480 | x2i = -s2 * y[n / 2 + n + 2 * k] + c2 * y[n / 2 + n + 2 * k + 1]; 481 | 482 | t = x1r + x2r; 483 | y[2 * k] = u1r + t; 484 | y[2 * k + n] = u1r - t; 485 | 486 | t = x1i + x2i; 487 | y[2 * k + 1] = u1i + t; 488 | y[2 * k + n + 1] = u1i - t; 489 | 490 | t = x2i - x1i; 491 | y[2 * k + n / 2] = u2r - t; 492 | y[2 * k + n + n / 2] = u2r + t; 493 | 494 | t = x1r - x2r; 495 | y[2 * k + n / 2 + 1] = u2i - t; 496 | y[2 * k + n + n / 2 + 1] = u2i + t; 497 | } 498 | 499 | } 500 | 501 | 502 | void ifft_primitive(float *input, float *output, int n, int stride, float *twiddle_factors, int tw_stride) 503 | { 504 | 505 | #if USE_SPLIT_RADIX 506 | split_radix_fft(input, output, n, stride, twiddle_factors, tw_stride); 507 | #else 508 | fft_primitive(input, output, n, stride, twiddle_factors, tw_stride); 509 | #endif 510 | 511 | int ks; 512 | 513 | int ns = n * stride; 514 | 515 | // reverse all coefficients from 1 to n / 2 - 1 516 | for (ks = stride ; ks < ns / 2 ; ks += stride) 517 | { 518 | float t; 519 | 520 | t = output[ks]; 521 | output[ks] = output[ns-ks]; 522 | output[ns-ks] = t; 523 | 524 | t = output[ks+1]; 525 | output[ks+1] = output[ns-ks+1]; 526 | output[ns-ks+1] = t; 527 | } 528 | 529 | // Apply normalization 530 | float norm = 1. / n; 531 | for (ks = 0 ; ks < ns ; ks += stride) 532 | { 533 | output[ks] *= norm; 534 | output[ks+1] *= norm; 535 | } 536 | 537 | } 538 | 539 | inline void fft8(float *input, int stride_in, float *output, int stride_out) 540 | { 541 | /* 542 | * Unrolled implementation of FFT8 for a little more performance 543 | */ 544 | float a0r, a1r, a2r, a3r, a4r, a5r, a6r, a7r; 545 | float a0i, a1i, a2i, a3i, a4i, a5i, a6i, a7i; 546 | float b0r, b1r, b2r, b3r, b4r, b5r, b6r, b7r; 547 | float b0i, b1i, b2i, b3i, b4i, b5i, b6i, b7i; 548 | float t; 549 | float sin_pi_4 = 0.7071067812; 550 | 551 | a0r = input[0]; 552 | a0i = input[1]; 553 | a1r = input[stride_in]; 554 | a1i = input[stride_in+1]; 555 | a2r = input[2*stride_in]; 556 | a2i = input[2*stride_in+1]; 557 | a3r = input[3*stride_in]; 558 | a3i = input[3*stride_in+1]; 559 | a4r = input[4*stride_in]; 560 | a4i = input[4*stride_in+1]; 561 | a5r = input[5*stride_in]; 562 | a5i = input[5*stride_in+1]; 563 | a6r = input[6*stride_in]; 564 | a6i = input[6*stride_in+1]; 565 | a7r = input[7*stride_in]; 566 | a7i = input[7*stride_in+1]; 567 | 568 | // Stage 1 569 | 570 | b0r = a0r + a4r; 571 | b0i = a0i + a4i; 572 | 573 | b1r = a1r + a5r; 574 | b1i = a1i + a5i; 575 | 576 | b2r = a2r + a6r; 577 | b2i = a2i + a6i; 578 | 579 | b3r = a3r + a7r; 580 | b3i = a3i + a7i; 581 | 582 | b4r = a0r - a4r; 583 | b4i = a0i - a4i; 584 | 585 | b5r = a1r - a5r; 586 | b5i = a1i - a5i; 587 | // W_8^1 = 1/sqrt(2) - j / sqrt(2) 588 | t = b5r + b5i; 589 | b5i = (b5i - b5r) * sin_pi_4; 590 | b5r = t * sin_pi_4; 591 | 592 | // W_8^2 = -j 593 | b6r = a2i - a6i; 594 | b6i = a6r - a2r; 595 | 596 | b7r = a3r - a7r; 597 | b7i = a3i - a7i; 598 | // W_8^3 = -1 / sqrt(2) + j / sqrt(2) 599 | t = sin_pi_4 * (b7i - b7r); 600 | b7i = - (b7r + b7i) * sin_pi_4; 601 | b7r = t; 602 | 603 | // Stage 2 604 | 605 | a0r = b0r + b2r; 606 | a0i = b0i + b2i; 607 | 608 | a1r = b1r + b3r; 609 | a1i = b1i + b3i; 610 | 611 | a2r = b0r - b2r; 612 | a2i = b0i - b2i; 613 | 614 | // * j 615 | a3r = b1i - b3i; 616 | a3i = b3r - b1r; 617 | 618 | a4r = b4r + b6r; 619 | a4i = b4i + b6i; 620 | 621 | a5r = b5r + b7r; 622 | a5i = b5i + b7i; 623 | 624 | a6r = b4r - b6r; 625 | a6i = b4i - b6i; 626 | 627 | // * j 628 | a7r = b5i - b7i; 629 | a7i = b7r - b5r; 630 | 631 | // Stage 3 632 | 633 | // X[0] 634 | output[0] = a0r + a1r; 635 | output[1] = a0i + a1i; 636 | 637 | // X[4] 638 | output[4*stride_out] = a0r - a1r; 639 | output[4*stride_out+1] = a0i - a1i; 640 | 641 | // X[2] 642 | output[2*stride_out] = a2r + a3r; 643 | output[2*stride_out+1] = a2i + a3i; 644 | 645 | // X[6] 646 | output[6*stride_out] = a2r - a3r; 647 | output[6*stride_out+1] = a2i - a3i; 648 | 649 | // X[1] 650 | output[stride_out] = a4r + a5r; 651 | output[stride_out+1] = a4i + a5i; 652 | 653 | // X[5] 654 | output[5*stride_out] = a4r - a5r; 655 | output[5*stride_out+1] = a4i - a5i; 656 | 657 | // X[3] 658 | output[3*stride_out] = a6r + a7r; 659 | output[3*stride_out+1] = a6i + a7i; 660 | 661 | // X[7] 662 | output[7*stride_out] = a6r - a7r; 663 | output[7*stride_out+1] = a6i - a7i; 664 | 665 | } 666 | 667 | inline void fft4(float *input, int stride_in, float *output, int stride_out) 668 | { 669 | /* 670 | * Unrolled implementation of FFT4 for a little more performance 671 | */ 672 | float t1, t2; 673 | 674 | t1 = input[0] + input[2*stride_in]; 675 | t2 = input[stride_in] + input[3*stride_in]; 676 | output[0] = t1 + t2; 677 | output[2*stride_out] = t1 - t2; 678 | 679 | t1 = input[1] + input[2*stride_in+1]; 680 | t2 = input[stride_in+1] + input[3*stride_in+1]; 681 | output[1] = t1 + t2; 682 | output[2*stride_out+1] = t1 - t2; 683 | 684 | t1 = input[0] - input[2*stride_in]; 685 | t2 = input[stride_in+1] - input[3*stride_in+1]; 686 | output[stride_out] = t1 + t2; 687 | output[3*stride_out] = t1 - t2; 688 | 689 | t1 = input[1] - input[2*stride_in+1]; 690 | t2 = input[3*stride_in] - input[stride_in]; 691 | output[stride_out+1] = t1 + t2; 692 | output[3*stride_out+1] = t1 - t2; 693 | } 694 | -------------------------------------------------------------------------------- /misc/esp32-fft/fft.h: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | ESP32 FFT 4 | ========= 5 | 6 | This provides a vanilla radix-2 FFT implementation and a test example. 7 | 8 | Author 9 | ------ 10 | 11 | This code was written by [Robin Scheibler](http://www.robinscheibler.org) during rainy days in October 2017. 12 | 13 | License 14 | ------- 15 | 16 | Copyright (c) 2017 Robin Scheibler 17 | 18 | Permission is hereby granted, free of charge, to any person obtaining a copy 19 | of this software and associated documentation files (the "Software"), to deal 20 | in the Software without restriction, including without limitation the rights 21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 22 | copies of the Software, and to permit persons to whom the Software is 23 | furnished to do so, subject to the following conditions: 24 | 25 | The above copyright notice and this permission notice shall be included in all 26 | copies or substantial portions of the Software. 27 | 28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 34 | SOFTWARE. 35 | 36 | */ 37 | #pragma once 38 | 39 | #ifndef __FFT_H__ 40 | #define __FFT_H__ 41 | 42 | typedef enum 43 | { 44 | FFT_REAL, 45 | FFT_COMPLEX 46 | } fft_type_t; 47 | 48 | typedef enum 49 | { 50 | FFT_FORWARD, 51 | FFT_BACKWARD 52 | } fft_direction_t; 53 | 54 | #define FFT_OWN_INPUT_MEM 1 55 | #define FFT_OWN_OUTPUT_MEM 2 56 | 57 | typedef struct 58 | { 59 | int size; // FFT size 60 | float *input; // pointer to input buffer 61 | float *output; // pointer to output buffer 62 | float *twiddle_factors; // pointer to buffer holding twiddle factors 63 | fft_type_t type; // real or complex 64 | fft_direction_t direction; // forward or backward 65 | unsigned int flags; // FFT flags 66 | } fft_config_t; 67 | 68 | fft_config_t *fft_init(int size, fft_type_t type, fft_direction_t direction, float *input, float *output); 69 | void fft_destroy(fft_config_t *config); 70 | void fft_execute(fft_config_t *config); 71 | void fft(float *input, float *output, float *twiddle_factors, int n); 72 | void ifft(float *input, float *output, float *twiddle_factors, int n); 73 | void rfft(float *x, float *y, float *twiddle_factors, int n); 74 | void irfft(float *x, float *y, float *twiddle_factors, int n); 75 | void fft_primitive(float *x, float *y, int n, int stride, float *twiddle_factors, int tw_stride); 76 | void split_radix_fft(float *x, float *y, int n, int stride, float *twiddle_factors, int tw_stride); 77 | void ifft_primitive(float *input, float *output, int n, int stride, float *twiddle_factors, int tw_stride); 78 | void fft8(float *input, int stride_in, float *output, int stride_out); 79 | void fft4(float *input, int stride_in, float *output, int stride_out); 80 | 81 | #endif // __FFT_H__ 82 | -------------------------------------------------------------------------------- /misc/esp32-fft/svd.c: -------------------------------------------------------------------------------- 1 | /* svd.c: Perform a singular value decomposition A = USV' of square matrix. 2 | * 3 | * This routine has been adapted with permission from a Pascal implementation 4 | * (c) 1988 J. C. Nash, "Compact numerical methods for computers", Hilger 1990. 5 | * The A matrix must be pre-allocated with 2n rows and n columns. On calling 6 | * the matrix to be decomposed is contained in the first n rows of A. On return 7 | * the n first rows of A contain the product US and the lower n rows contain V 8 | * (not V'). The S2 vector returns the square of the singular values. 9 | * 10 | * (c) Copyright 1996 by Carl Edward Rasmussen. */ 11 | 12 | #include 13 | #include 14 | 15 | void svd(double **A, double *S2, int n) 16 | { 17 | int i, j, k, EstColRank = n, RotCount = n, SweepCount = 0, 18 | slimit = (n<120) ? 30 : n/4; 19 | double eps = 1e-15, e2 = 10.0*n*eps*eps, tol = 0.1*eps, vt, p, x0, 20 | y0, q, r, c0, s0, d1, d2; 21 | 22 | for (i=0; i= r) { 34 | if (q<=e2*S2[0] || fabs(p)<=tol*q) 35 | RotCount--; 36 | else { 37 | p /= q; r = 1.0-r/q; vt = sqrt(4.0*p*p+r*r); 38 | c0 = sqrt(0.5*(1.0+r/vt)); s0 = p/(vt*c0); 39 | for (i=0; i<2*n; i++) { 40 | d1 = A[i][j]; d2 = A[i][k]; 41 | A[i][j] = d1*c0+d2*s0; A[i][k] = -d1*s0+d2*c0; 42 | } 43 | } 44 | } else { 45 | p /= r; q = q/r-1.0; vt = sqrt(4.0*p*p+q*q); 46 | s0 = sqrt(0.5*(1.0-q/vt)); 47 | if (p<0.0) s0 = -s0; 48 | c0 = p/(vt*s0); 49 | for (i=0; i<2*n; i++) { 50 | d1 = A[i][j]; d2 = A[i][k]; 51 | A[i][j] = d1*c0+d2*s0; A[i][k] = -d1*s0+d2*c0; 52 | } 53 | } 54 | } 55 | while (EstColRank>2 && S2[EstColRank-1]<=S2[0]*tol+tol*tol) EstColRank--; 56 | } 57 | if (SweepCount > slimit) 58 | printf("Warning: Reached maximum number of sweeps (%d) in SVD routine...\n" 59 | ,slimit); 60 | } 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==3.3.4 2 | appnope==0.1.2 3 | argon2-cffi==21.1.0 4 | attrs==21.2.0 5 | Babel==2.9.1 6 | backcall==0.2.0 7 | backports.entry-points-selectable==1.1.0 8 | bleach==4.1.0 9 | certifi==2021.10.8 10 | cffi==1.15.0 11 | charset-normalizer==2.0.7 12 | cycler==0.11.0 13 | debugpy==1.5.1 14 | decorator==5.1.0 15 | defusedxml==0.7.1 16 | distlib==0.3.3 17 | entrypoints==0.3 18 | filelock==3.3.2 19 | fonttools==4.28.1 20 | idna==3.3 21 | ipykernel==6.5.0 22 | ipython==7.29.0 23 | ipython-genutils==0.2.0 24 | jedi==0.18.0 25 | Jinja2==3.0.2 26 | json5==0.9.6 27 | jsonschema==4.1.2 28 | jupyter-client==7.0.6 29 | jupyter-core==4.9.1 30 | jupyter-server==1.11.2 31 | jupyter_micropython_kernel==0.1.3.2 32 | jupyterlab==3.2.1 33 | jupyterlab-pygments==0.1.2 34 | jupyterlab-server==2.8.2 35 | kiwisolver==1.3.2 36 | MarkupSafe==2.0.1 37 | matplotlib==3.5.0 38 | matplotlib-inline==0.1.3 39 | mistune==0.8.4 40 | nbclassic==0.3.4 41 | nbclient==0.5.4 42 | nbconvert==6.2.0 43 | nbformat==5.1.3 44 | nest-asyncio==1.5.1 45 | notebook==6.4.5 46 | numpy==1.21.4 47 | packaging==21.2 48 | padasip==1.1.1 49 | pandas==1.3.4 50 | pandocfilters==1.5.0 51 | parso==0.8.2 52 | pexpect==4.8.0 53 | pickleshare==0.7.5 54 | Pillow==8.4.0 55 | platformdirs==2.4.0 56 | prometheus-client==0.12.0 57 | prompt-toolkit==3.0.21 58 | ptyprocess==0.7.0 59 | pycparser==2.20 60 | Pygments==2.10.0 61 | pyparsing==2.4.7 62 | pyrsistent==0.18.0 63 | pyserial==3.5 64 | python-dateutil==2.8.2 65 | pytz==2021.3 66 | pyzmq==22.3.0 67 | requests==2.26.0 68 | scipy==1.7.2 69 | seaborn==0.11.2 70 | Send2Trash==1.8.0 71 | setuptools-scm==6.3.2 72 | six==1.16.0 73 | sniffio==1.2.0 74 | terminado==0.12.1 75 | testpath==0.5.0 76 | tomli==1.2.2 77 | tornado==6.1 78 | traitlets==5.1.1 79 | urllib3==1.26.7 80 | virtualenv==20.10.0 81 | wcwidth==0.2.5 82 | webencodings==0.5.1 83 | websocket-client==1.2.1 84 | -------------------------------------------------------------------------------- /ui/ssvep_squares.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 |
11 | 12 |
13 |

up

14 |
15 |
16 |
17 |
18 |

left

19 |
20 | 21 |
22 |
23 |

right

24 |
25 |
26 |
27 |
28 |

down

29 |
30 |
31 | 33 | 36 | 37 | 40 | 43 | 46 | 48 | 50 | 53 | 54 | 56 | 59 | 60 | 61 | 64 | 67 | 71 | 72 | 74 | 77 | 80 | 82 | 85 | 86 |
87 |
88 | 89 | 121 | 122 | 123 | --------------------------------------------------------------------------------