├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── base.yaml
│   ├── logging
│   │   └── local.yaml
│   ├── model
│   │   ├── dvs
│   │   │   └── small.yaml
│   │   ├── shd
│   │   │   ├── medium.yaml
│   │   │   └── tiny.yaml
│   │   └── ssc
│   │       ├── medium.yaml
│   │       └── small.yaml
│   ├── system
│   │   └── local.yaml
│   └── task
│       ├── dvs-gesture.yaml
│       ├── spiking-heidelberg-digits.yaml
│       ├── spiking-speech-commands.yaml
│       └── tutorial.yaml
├── docs
│   └── figure1.png
├── event_ssm
│   ├── __init__.py
│   ├── dataloading.py
│   ├── layers.py
│   ├── seq_model.py
│   ├── ssm.py
│   ├── ssm_init.py
│   ├── train_utils.py
│   ├── trainer.py
│   └── transform.py
├── requirements.txt
├── run_evaluation.py
├── run_training.py
├── setup.py
├── tutorial_inference.ipynb
├── tutorial_online_inference.ipynb
└── tutorial_training.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | *.pyc 162 | 163 | # S5 specific stuff 164 | wandb/ 165 | cache_dir/ 166 | raw_datasets/ 167 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Mark Schoene 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models 2 | ![Figure 1](docs/figure1.png) 3 | This is the official implementation of our paper [Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models 4 | ](https://arxiv.org/abs/2404.18508). 5 | The core motivation for this work was the irregular time-series modeling problem presented in the paper [Simplified State Space Layers for Sequence Modeling 6 | ](https://arxiv.org/abs/2208.04933). 
7 | We acknowledge the awesome [S5 project](https://github.com/lindermanlab/S5) and the trainer class provided by this [UvA tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/guide4/Research_Projects_with_JAX.html), which strongly influenced our code.
8 | 
9 | Our project addresses a quite general machine learning problem:
10 | Modeling **long sequences** that are **irregularly** sampled by a possibly large number of **asynchronous** sensors.
11 | This problem is particularly present in the field of neuromorphic computing, where event-based sensors emit up to millions of events per second from asynchronous channels.
12 | 
13 | We show how linear state-space models can be tuned to effectively model asynchronous event-based sequences.
14 | Our contributions are:
15 | - Integration of Dirac delta coded event streams
16 | - Time-invariant input normalization to effectively learn from long event streams
17 | - Formulation of neuromorphic event streams as a language modeling problem with **asynchronous tokens**
18 | - Effective modeling of event-based vision **without frames and without CNNs**
19 | 
20 | ## Installation
21 | The project is implemented in [JAX](https://github.com/google/jax) with [Flax](https://flax.readthedocs.io/en/latest/).
22 | By default, we install JAX with GPU support for CUDA >= 12.0.
23 | To install JAX for CPU, replace `jax[cuda]` with `jax[cpu]` in the `requirements.txt` file.
24 | PyTorch is only required for loading data.
25 | Therefore, we install only the CPU version of PyTorch.
26 | Install the requirements with
27 | ```bash
28 | pip install -r requirements.txt
29 | ```
30 | Install this repository with
31 | ```bash
32 | pip install -e .
33 | ```
34 | We tested with JAX versions between `0.4.20` and `0.4.29`.
35 | Different CUDA and JAX versions might result in slightly different results.
36 | 
37 | ## Reproducing experiments
38 | We use the [hydra](https://hydra.cc/docs/intro/) package to manage configurations.
39 | If you are not familiar with hydra, we recommend reading the [documentation](https://hydra.cc/docs/intro/).
40 | 
41 | ### Run benchmark tasks
42 | The basic command to run an experiment is
43 | ```bash
44 | python run_training.py
45 | ```
46 | This will default to training on the Spiking Heidelberg Digits (SHD) dataset.
47 | All benchmark tasks are defined by the configurations in `configs/task/`, and can be run by specifying the `task` argument.
48 | E.g., run the Spiking Speech Commands (SSC) task with
49 | ```bash
50 | python run_training.py task=spiking-speech-commands
51 | ```
52 | or run the DVS128 Gestures task with
53 | ```bash
54 | python run_training.py task=dvs-gesture
55 | ```
56 | 
57 | ### Trained models
58 | We provide our best models for [download](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).
59 | Check out the `tutorial_inference.ipynb` notebook to see how to load and run inference with these models.
60 | We also provide a script to evaluate the models on the test set:
61 | ```bash
62 | python run_evaluation.py task=spiking-speech-commands checkpoint=downloaded/model/SSC
63 | ```
64 | 
65 | 
66 | ### Specify HPC system and logging
67 | Many researchers operate on different HPC systems and perhaps log their experiments to multiple platforms.
68 | Therefore, the user can specify configurations for
69 | - different systems (directories for reading data and saving outputs)
70 | - logging methods (e.g. whether to log locally or to [wandb](https://wandb.ai/))
71 | 
72 | By default, the `configs/system/local.yaml` and `configs/logging/local.yaml` configurations are used, respectively.
73 | We suggest creating new configs for the HPC systems and wandb projects you are using.
74 | 
75 | For example, to run the model on SSC with your custom wandb logging config and your custom HPC specification, run
76 | ```bash
77 | python run_training.py task=spiking-speech-commands logging=wandb system=hpc
78 | ```
79 | where `configs/logging/wandb.yaml` should look like
80 | ```yaml
81 | log_dir: ${output_dir}
82 | interval: 1000
83 | wandb: True
84 | summary_metric: "Performance/Validation accuracy"
85 | project: wandb_project_name
86 | entity: wandb_entity_name
87 | ```
88 | and `configs/system/hpc.yaml` should specify data and output directories
89 | ```yaml
90 | # @package _global_
91 | 
92 | data_dir: my/fast/storage/location/data
93 | output_dir: my/job/output/location/${task.name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d-%H-%M-%S}
94 | ```
95 | The string `${task.name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d-%H-%M-%S}` will create subdirectories named by task, SLURM job ID, and date,
96 | which we found useful in practice.
97 | This specification of the `output_dir` is not required though.
98 | 
99 | ## Tutorials
100 | To get started with event-based state-space models, we created tutorials for training and inference.
101 | - `tutorial_training.ipynb` shows how to train a model on a reduced version of the Spiking Heidelberg Digits with just two classes. The model converges after a few minutes on CPU.
102 | - `tutorial_inference.ipynb` shows how to load a trained model and run inference. The models are available for download from the provided [download link](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).
103 | - `tutorial_online_inference.ipynb` runs event-by-event inference with batch size one (online inference) on the DVS128 Gestures dataset and measures the throughput of the model.
104 | 
105 | ## Help and support
106 | We are eager to help you with any questions or issues you might have.
107 | Please use the GitHub issue tracker for questions and to report issues.
108 | 
109 | ## Citation
110 | Please use the following when citing our work:
111 | ```
112 | @misc{Schoene2024,
113 |       title={Scalable Event-by-event Processing of Neuromorphic Sensory Signals With Deep State-Space Models},
114 |       author={Mark Schöne and Neeraj Mohan Sushma and Jingyue Zhuge and Christian Mayr and Anand Subramoney and David Kappel},
115 |       year={2024},
116 |       eprint={2404.18508},
117 |       archivePrefix={arXiv},
118 |       primaryClass={cs.LG}
119 | }
120 | ```
121 | 

--------------------------------------------------------------------------------
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - _self_
3 |   - system: local
4 |   - task: spiking-heidelberg-digits
5 |   - logging: local
6 | 
7 | seed: 1234
8 | checkpoint: null
9 | 
10 | hydra:
11 |   run:
12 |     dir: ${output_dir}/hydra-outputs/${now:%Y-%m-%d-%H-%M-%S}

--------------------------------------------------------------------------------
/configs/logging/local.yaml:
--------------------------------------------------------------------------------
1 | log_dir: ${output_dir}
2 | interval: 1000
3 | wandb: False
4 | summary_metric: "Performance/Validation accuracy"
5 | project: ???
6 | entity: ???
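A quick way to see how `base.yaml`, the `system`/`task`/`logging` groups, and command-line overrides compose is a minimal hydra entry point. The sketch below is illustrative only: the file name `inspect_config.py` is hypothetical, and the real entry points of this repository are `run_training.py` and `run_evaluation.py`.

```python
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="configs", config_name="base")
def main(cfg: DictConfig) -> None:
    # At this point hydra has merged base.yaml with the selected
    # system/task/logging groups and any command-line overrides.
    # resolve=False leaves interpolations like ${oc.env:SLURM_JOB_ID}
    # unresolved, so printing works even outside a SLURM job.
    print(OmegaConf.to_yaml(cfg, resolve=False))


if __name__ == "__main__":
    main()
```

For example, `python inspect_config.py task=dvs-gesture logging.interval=500` would print the composed configuration for the DVS task with an overridden logging interval.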
-------------------------------------------------------------------------------- /configs/model/dvs/small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | ssm_init: 5 | C_init: lecun_normal 6 | dt_min: 0.001 7 | dt_max: 0.1 8 | conj_sym: false 9 | clip_eigs: true 10 | ssm: 11 | discretization: async 12 | d_model: 128 13 | d_ssm: 128 14 | ssm_block_size: 16 15 | num_stages: 2 16 | num_layers_per_stage: 3 17 | dropout: 0.25 18 | classification_mode: timepool 19 | prenorm: true 20 | batchnorm: false 21 | bn_momentum: 0.95 22 | pooling_stride: 16 23 | pooling_mode: timepool 24 | state_expansion_factor: 2 25 | -------------------------------------------------------------------------------- /configs/model/shd/medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | ssm_init: 5 | C_init: lecun_normal 6 | dt_min: 0.004 7 | dt_max: 0.1 8 | conj_sym: false 9 | clip_eigs: false 10 | ssm: 11 | discretization: async 12 | d_model: 96 13 | d_ssm: 128 14 | ssm_block_size: 8 15 | num_stages: 2 16 | num_layers_per_stage: 3 17 | dropout: 0.23 18 | classification_mode: pool 19 | prenorm: true 20 | batchnorm: false 21 | bn_momentum: 0.95 22 | pooling_stride: 8 23 | pooling_mode: avgpool 24 | state_expansion_factor: 1 -------------------------------------------------------------------------------- /configs/model/shd/tiny.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | ssm_init: 5 | C_init: lecun_normal 6 | dt_min: 0.004 7 | dt_max: 0.1 8 | conj_sym: false 9 | clip_eigs: false 10 | ssm: 11 | discretization: async 12 | d_model: 16 13 | d_ssm: 16 14 | ssm_block_size: 8 15 | num_stages: 1 16 | num_layers_per_stage: 6 17 | dropout: 0.1 18 | classification_mode: timepool 19 | prenorm: true 20 | batchnorm: false 21 | bn_momentum: 0.95 22 | pooling_stride: 32 23 | pooling_mode: timepool 24 | state_expansion_factor: 1 25 | -------------------------------------------------------------------------------- /configs/model/ssc/medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | ssm_init: 5 | C_init: lecun_normal 6 | dt_min: 0.0015 7 | dt_max: 0.1 8 | conj_sym: true 9 | clip_eigs: false 10 | ssm: 11 | discretization: async 12 | d_model: 96 13 | d_ssm: 128 14 | ssm_block_size: 16 15 | num_stages: 2 16 | num_layers_per_stage: 3 17 | dropout: 0.23 18 | classification_mode: pool 19 | prenorm: true 20 | batchnorm: true 21 | bn_momentum: 0.95 22 | pooling_stride: 8 23 | pooling_mode: avgpool 24 | state_expansion_factor: 2 25 | -------------------------------------------------------------------------------- /configs/model/ssc/small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | ssm_init: 5 | C_init: lecun_normal 6 | dt_min: 0.002 7 | dt_max: 0.1 8 | conj_sym: true 9 | clip_eigs: false 10 | ssm: 11 | discretization: async 12 | d_model: 64 13 | d_ssm: 64 14 | ssm_block_size: 8 15 | num_stages: 1 16 | num_layers_per_stage: 6 17 | dropout: 0.27 18 | classification_mode: timepool 19 | prenorm: true 20 | batchnorm: true 21 | bn_momentum: 0.95 22 | pooling_stride: 8 23 | pooling_mode: timepool 24 | state_expansion_factor: 1 -------------------------------------------------------------------------------- /configs/system/local.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | data_dir: ./data 4 | output_dir: ./outputs/${now:%Y-%m-%d-%H-%M-%S} 5 | checkpoint_dir: ./checkpoints -------------------------------------------------------------------------------- /configs/task/dvs-gesture.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: dvs/small 4 | 5 | task: 6 | name: dvs-gesture-classification 7 | 8 | training: 9 | num_epochs: 100 10 | per_device_batch_size: 16 11 | per_device_eval_batch_size: 4 12 | num_workers: 4 13 | time_jitter: 5 14 | spatial_jitter: 0.3 15 | noise: 0.0 16 | drop_event: 0.05 17 | time_skew: 1.12 18 | max_roll: 32 19 | max_angle: 10 20 | max_scale: 1.2 21 | max_drop_chunk: 0.02 22 | cut_mix: 0.4 23 | pad_unit: 524288 24 | slice_events: 65536 25 | validate_on_test: true 26 | 27 | optimizer: 28 | ssm_base_lr: 0.000012 29 | lr_factor: 6 30 | warmup_epochs: 10 31 | ssm_weight_decay: 0.0 32 | weight_decay: 0.02 33 | schedule: cosine 34 | accumulation_steps: 4 -------------------------------------------------------------------------------- /configs/task/spiking-heidelberg-digits.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: shd/medium 4 | 5 | task: 6 | name: shd-classification 7 | 8 | training: 9 | num_epochs: 30 10 | per_device_batch_size: 32 11 | per_device_eval_batch_size: 128 12 | num_workers: 4 13 | time_jitter: 1 14 | spatial_jitter: 0.55 15 | noise: 35 16 | max_drop_chunk: 0.02 17 | drop_event: 0.1 18 | time_skew: 1.2 19 | cut_mix: 0.3 20 | pad_unit: 8192 21 | validate_on_test: true 22 | 23 | optimizer: 24 | ssm_base_lr: 1.7e-5 25 | lr_factor: 10 26 | warmup_epochs: 3 27 | ssm_weight_decay: 0.0 28 | weight_decay: 0.03 29 | schedule: cosine 30 | accumulation_steps: 1 -------------------------------------------------------------------------------- /configs/task/spiking-speech-commands.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: ssc/medium 4 | 5 | task: 6 | name: ssc-classification 7 | 8 | training: 9 | num_epochs: 200 10 | per_device_batch_size: 64 11 | per_device_eval_batch_size: 128 12 | num_workers: 4 13 | time_jitter: 3 14 | spatial_jitter: 1.0 15 | noise: 100 16 | drop_event: 0.1 17 | max_drop_chunk: 0.02 18 | cut_mix: 0.3 19 | time_skew: 1.05 20 | pad_unit: 8192 21 | 22 | optimizer: 23 | ssm_base_lr: 0.000005 24 | lr_factor: 5 25 | warmup_epochs: 20 26 | ssm_weight_decay: 0.0 27 | weight_decay: 0.05 28 | schedule: cosine 29 | accumulation_steps: 1 -------------------------------------------------------------------------------- /configs/task/tutorial.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model: shd/tiny 4 | 5 | task: 6 | name: shd-classification 7 | 8 | training: 9 | num_epochs: 5 10 | per_device_batch_size: 16 11 | per_device_eval_batch_size: 16 12 | num_workers: 4 13 | time_jitter: 1 14 | spatial_jitter: 0.55 15 | noise: 35 16 | max_drop_chunk: 0.02 17 | drop_event: 0.1 18 | time_skew: 1.2 19 | cut_mix: 0.3 20 | pad_unit: 8192 21 | validate_on_test: false 22 | 23 | optimizer: 24 | ssm_base_lr: 5e-5 25 | lr_factor: 10 26 | warmup_epochs: 1 27 | ssm_weight_decay: 0.0 28 | weight_decay: 0.01 29 | schedule: cosine 30 | accumulation_steps: 1 
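The `optimizer` blocks in the task configs above all share the same fields. As a rough illustration of how they could translate into an optax optimizer, consider the sketch below. This is an assumption for orientation only: the repository's actual implementation lives in `event_ssm/train_utils.py` (not shown here), and the helper name `make_optimizer`, the choice of `adamw`, and the exact warmup/decay arithmetic are hypothetical.

```python
import optax


def make_optimizer(optimizer_cfg, num_epochs, steps_per_epoch):
    # Hypothetical mapping of the YAML fields onto optax primitives.
    # ssm_base_lr is the (smaller) learning rate for the SSM parameters;
    # lr_factor scales it up for the remaining parameters.
    peak_lr = optimizer_cfg.ssm_base_lr * optimizer_cfg.lr_factor

    # 'schedule: cosine' with a linear warmup over warmup_epochs.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=optimizer_cfg.warmup_epochs * steps_per_epoch,
        decay_steps=num_epochs * steps_per_epoch,
    )

    tx = optax.adamw(learning_rate=schedule, weight_decay=optimizer_cfg.weight_decay)

    # accumulation_steps > 1 corresponds to gradient accumulation;
    # optax.MultiSteps applies the accumulated update every k-th step.
    if optimizer_cfg.accumulation_steps > 1:
        tx = optax.MultiSteps(tx, every_k_schedule=optimizer_cfg.accumulation_steps)
    return tx
```

Note that the configs keep `ssm_weight_decay: 0.0` separate from `weight_decay`, which suggests parameter-group-specific treatment (e.g. via `optax.multi_transform`); that split is omitted in this sketch for brevity.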
--------------------------------------------------------------------------------
/docs/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Efficient-Scalable-Machine-Learning/event-ssm/d9ceb07c6f669086537e279dfbe3d8cdb5a70fbe/docs/figure1.png
--------------------------------------------------------------------------------
/event_ssm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Efficient-Scalable-Machine-Learning/event-ssm/d9ceb07c6f669086537e279dfbe3d8cdb5a70fbe/event_ssm/__init__.py
--------------------------------------------------------------------------------
/event_ssm/dataloading.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from typing import Callable, Optional, TypeVar, Dict, Tuple, List, Union
4 | import tonic
5 | from functools import partial
6 | import numpy as np
7 | from event_ssm.transform import Identity, Roll, Rotate, Scale, DropEventChunk, Jitter1D, OneHotLabels, cut_mix_augmentation
8 | 
9 | DEFAULT_CACHE_DIR_ROOT = Path('./cache_dir/')
10 | 
11 | DataLoader = TypeVar('DataLoader')
12 | InputType = [str, Optional[int], Optional[int]]
13 | 
14 | 
15 | class Data:
16 |     """
17 |     Data class for storing dataset specific information
18 |     """
19 |     def __init__(
20 |             self,
21 |             n_classes: int,
22 |             num_embeddings: int,
23 |             train_size: int
24 |     ):
25 |         self.n_classes = n_classes
26 |         self.num_embeddings = num_embeddings
27 |         self.train_size = train_size
28 | 
29 | 
30 | def event_stream_collate_fn(batch, resolution, pad_unit, cut_mix=0.0, no_time_information=False):
31 |     """
32 |     Collate function to turn event stream data into tokens ready for the JAX model
33 | 
34 |     :param batch: list of tuples of (events, target)
35 |     :param resolution: resolution of the event stream
36 |     :param pad_unit: padding unit for the tokens. All sequences will be padded to integer multiples of this unit.
37 |                      This option results in JAX compiling multiple GPU kernels for different sequence lengths,
38 |                      which might slow down compilation time, but improves throughput for the rest of the training process.
39 |     :param cut_mix: probability of applying cut mix augmentation
40 |     :param no_time_information: if True, the time information is ignored and all events are treated as if they were
41 |                      sampled at uniform time intervals.
42 |                      This option is only used for ablation studies.
43 |     """
44 |     # x are inputs, y are targets, z are aux data
45 |     x, y, *z = zip(*batch)
46 |     assert len(z) == 0
47 |     batch_size_one = len(x) == 1
48 | 
49 |     # apply cut mix augmentation
50 |     if np.random.rand() < cut_mix:
51 |         x, y = cut_mix_augmentation(x, y)
52 | 
53 |     # set labels to numpy array
54 |     y = np.stack(y)
55 | 
56 |     # integration time steps are the difference between two consecutive time stamps
57 |     if no_time_information:
58 |         timesteps = [np.ones_like(e['t'][:-1]) for e in x]
59 |     else:
60 |         timesteps = [np.diff(e['t']) for e in x]
61 | 
62 |     # NOTE: since timesteps are deltas, their length is L - 1, and we have to remove the last token in the following
63 | 
64 |     # process tokens for single input dim (e.g. audio)
65 |     if len(resolution) == 1:
66 |         tokens = [e['x'][:-1].astype(np.int32) for e in x]
67 |     elif len(resolution) == 2:
68 |         tokens = [(e['x'][:-1].astype(np.int32) + resolution[0] * e['y'][:-1].astype(np.int32) + np.prod(resolution) * e['p'][:-1].astype(np.int32)) for e in x]  # unique token id per (x, y, p): x + width * y, offset by width * height per polarity
69 |     else:
70 |         raise ValueError('resolution must contain 1 or 2 elements')
71 | 
72 |     # get padding lengths
73 |     lengths = np.array([len(e) for e in timesteps], dtype=np.int32)
74 |     pad_length = (lengths.max() // pad_unit) * pad_unit + pad_unit
75 | 
76 |     # pad tokens with -1, which results in a zero vector with embedding look-ups
77 |     tokens = np.stack(
78 |         [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=-1) for e in tokens])
79 |     timesteps = np.stack(
80 |         [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=0) for e in timesteps])
81 | 
82 |     # timesteps are in microseconds... transform to milliseconds
83 |     timesteps = timesteps / 1000
84 | 
85 |     if batch_size_one:
86 |         lengths = lengths[None, ...]
87 | 
88 |     return tokens, y, timesteps, lengths
89 | 
90 | 
91 | def event_stream_dataloader(
92 |         train_data,
93 |         val_data,
94 |         test_data,
95 |         batch_size,
96 |         eval_batch_size,
97 |         train_collate_fn,
98 |         eval_collate_fn,
99 |         rng,
100 |         num_workers=0,
101 |         shuffle_training=True
102 | ):
103 |     """
104 |     Create dataloaders for training, validation and testing
105 | 
106 |     :param train_data: training dataset
107 |     :param val_data: validation dataset
108 |     :param test_data: test dataset
109 |     :param batch_size: batch size for training
110 |     :param eval_batch_size: batch size for evaluation
111 |     :param train_collate_fn: collate function for training
112 |     :param eval_collate_fn: collate function for evaluation
113 |     :param rng: random number generator
114 |     :param num_workers: number of workers for data loading
115 |     :param shuffle_training: whether to shuffle the training data
116 | 
117 |     :return: train_loader, val_loader, test_loader
118 |     """
119 |     def dataloader(dset, bsz, collate_fn, shuffle, drop_last):
120 |         return torch.utils.data.DataLoader(
121 |             dset,
122 |             batch_size=bsz,
123 |             drop_last=drop_last,
124 |             collate_fn=collate_fn,
125 |             shuffle=shuffle,
126 |             generator=rng,
127 |             num_workers=num_workers
128 |         )
129 |     train_loader = dataloader(train_data, batch_size, train_collate_fn, shuffle=shuffle_training, drop_last=True)
130 |     val_loader = dataloader(val_data, eval_batch_size, eval_collate_fn, shuffle=False, drop_last=True)
131 |     test_loader = dataloader(test_data, eval_batch_size, eval_collate_fn, shuffle=False, drop_last=False)
132 |     return train_loader, val_loader, test_loader
133 | 
134 | 
135 | def create_events_shd_classification_dataset(
136 |         cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,
137 |         per_device_batch_size: int = 32,
138 |         per_device_eval_batch_size: int = 64,
139 |         world_size: int = 1,
140 |         num_workers: int = 0,
141 |         seed: int = 42,
142 |         time_jitter: float = 100,
143 |         spatial_jitter: float = 1.0,
144 |         max_drop_chunk: float = 0.1,
145 |         noise: int = 100,
146 |         drop_event: float = 0.1,
147 |         time_skew: float = 1.1,
148 |         cut_mix: float = 0.5,
149 |         pad_unit: int = 8192,
150 |         validate_on_test: bool = False,
151 |         no_time_information: bool = False,
152 |         **kwargs
153 | ) -> Tuple[DataLoader, DataLoader, DataLoader, Data]:
154 |     """
155 |     Creates a view of the Spiking Heidelberg Digits dataset
156 | 
157 |     :param cache_dir: (str) Where to store the dataset
158 |     :param per_device_batch_size: (int) Batch size per device; evaluation uses per_device_eval_batch_size.
159 |     :param seed: (int) Seed for shuffling data.
160 |     :param time_jitter: (float) Standard deviation of the time jitter.
161 |     :param spatial_jitter: (float) Standard deviation of the spatial jitter.
162 |     :param max_drop_chunk: (float) Maximum fraction of events to drop in a single chunk.
163 |     :param noise: (int) Number of noise events to add.
164 |     :param drop_event: (float) Probability of dropping an event.
165 |     :param time_skew: (float) Time skew factor.
166 |     :param cut_mix: (float) Probability of applying cut mix augmentation.
167 |     :param pad_unit: (int) Padding unit for the tokens. See collate function for more details.
168 |     :param validate_on_test: (bool) If True, use the test set for validation.
169 |                      Else use a random validation split from the training set.
170 |     :param no_time_information: (bool) Whether to ignore the time information in the events.
171 | 
172 |     :return: train_loader, val_loader, test_loader, data
173 |     """
174 |     print("[*] Generating Spiking Heidelberg Digits Classification Dataset")
175 | 
176 |     if seed is not None:
177 |         rng = torch.Generator()
178 |         rng.manual_seed(seed)
179 |     else:
180 |         rng = None
181 | 
182 |     sensor_size = (700, 1, 1)
183 | 
184 |     transforms = tonic.transforms.Compose([
185 |         tonic.transforms.DropEvent(p=drop_event),
186 |         DropEventChunk(p=0.3, max_drop_size=max_drop_chunk),
187 |         Jitter1D(sensor_size=sensor_size, var=spatial_jitter),
188 |         tonic.transforms.TimeSkew(coefficient=(1 / time_skew, time_skew), offset=0),
189 |         tonic.transforms.TimeJitter(std=time_jitter, clip_negative=False, sort_timestamps=True),
190 |         tonic.transforms.UniformNoise(sensor_size=sensor_size, n=(0, noise))
191 |     ])
192 |     target_transforms = OneHotLabels(num_classes=20)
193 | 
194 |     train_data = tonic.datasets.SHD(save_to=cache_dir, train=True, transform=transforms, target_transform=target_transforms)
195 |     val_data = tonic.datasets.SHD(save_to=cache_dir, train=True, target_transform=target_transforms)
196 |     test_data = tonic.datasets.SHD(save_to=cache_dir, train=False, target_transform=target_transforms)
197 | 
198 |     # create validation set
199 |     if validate_on_test:
200 |         print("[*] WARNING: Using test set for validation")
201 |         val_data = tonic.datasets.SHD(save_to=cache_dir, train=False, target_transform=target_transforms)
202 |     else:
203 |         val_length = int(0.1 * len(train_data))
204 |         indices = torch.randperm(len(train_data), generator=rng)
205 |         train_data = torch.utils.data.Subset(train_data, indices[:-val_length])
206 |         val_data = torch.utils.data.Subset(val_data, indices[-val_length:])
207 | 
208 |     collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)
209 |     train_loader, val_loader, test_loader = event_stream_dataloader(
210 |         train_data, val_data, test_data,
211 |         train_collate_fn=partial(collate_fn, cut_mix=cut_mix),
212 |         eval_collate_fn=collate_fn,
213 |         batch_size=per_device_batch_size * world_size, eval_batch_size=per_device_eval_batch_size * world_size,
214 |         rng=rng, num_workers=num_workers, shuffle_training=True
215 |     )
216 |     data = Data(
217 |         n_classes=20, num_embeddings=700, train_size=len(train_data)
218 |     )
219 |     return train_loader, val_loader, test_loader, data
220 | 
221 | 
222 | def create_events_ssc_classification_dataset(
223 |         cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,
224 |         per_device_batch_size: int = 32,
225 |         per_device_eval_batch_size: int = 64,
226 |         world_size: int = 1,
227 |         num_workers: int = 0,
228 |         seed: int = 42,
229 |         time_jitter: float = 100,
230 |         spatial_jitter: float = 1.0,
231 |         max_drop_chunk: float = 0.1,
232 |         noise: int = 100,
233 |         drop_event: float = 0.1,
234 |         time_skew: float = 1.1,
235 |         cut_mix: float = 0.5,
236 |         pad_unit: int = 8192,
237 |         no_time_information: bool = False,
238 |         **kwargs
239 | ) -> Tuple[DataLoader, DataLoader, DataLoader, Data]:
240 |     """
241 |     Creates a view of the Spiking Speech Commands dataset
242 | 
243 |     :param cache_dir: (str) Where to store the dataset
244 |     :param per_device_batch_size: (int) Batch size per device; evaluation uses per_device_eval_batch_size.
245 |     :param seed: (int) Seed for shuffling data.
246 |     :param time_jitter: (float) Standard deviation of the time jitter.
247 |     :param spatial_jitter: (float) Standard deviation of the spatial jitter.
248 |     :param max_drop_chunk: (float) Maximum fraction of events to drop in a single chunk.
249 |     :param noise: (int) Number of noise events to add.
250 |     :param drop_event: (float) Probability of dropping an event.
251 |     :param time_skew: (float) Time skew factor.
252 |     :param cut_mix: (float) Probability of applying cut mix augmentation.
253 |     :param pad_unit: (int) Padding unit for the tokens. See collate function for more details.
254 |     :param no_time_information: (bool) Whether to ignore the time information in the events.
255 | 
256 |     :return: train_loader, val_loader, test_loader, data
257 |     """
258 |     print("[*] Generating Spiking Speech Commands Classification Dataset")
259 | 
260 |     if seed is not None:
261 |         rng = torch.Generator()
262 |         rng.manual_seed(seed)
263 |     else:
264 |         rng = None
265 | 
266 |     sensor_size = (700, 1, 1)
267 | 
268 |     transforms = tonic.transforms.Compose([
269 |         tonic.transforms.DropEvent(p=drop_event),
270 |         DropEventChunk(p=0.3, max_drop_size=max_drop_chunk),
271 |         Jitter1D(sensor_size=sensor_size, var=spatial_jitter),
272 |         tonic.transforms.TimeSkew(coefficient=(1 / time_skew, time_skew), offset=0),
273 |         tonic.transforms.TimeJitter(std=time_jitter, clip_negative=False, sort_timestamps=True),
274 |         tonic.transforms.UniformNoise(sensor_size=sensor_size, n=(0, noise))
275 |     ])
276 |     target_transforms = OneHotLabels(num_classes=35)
277 | 
278 |     train_data = tonic.datasets.SSC(save_to=cache_dir, split='train', transform=transforms, target_transform=target_transforms)
279 |     val_data = tonic.datasets.SSC(save_to=cache_dir, split='valid', target_transform=target_transforms)
280 |     test_data = tonic.datasets.SSC(save_to=cache_dir, split='test', target_transform=target_transforms)
281 | 
282 |     collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)
283 |     train_loader, val_loader, test_loader = event_stream_dataloader(
284 |         train_data, val_data, test_data,
285 |         train_collate_fn=partial(collate_fn, cut_mix=cut_mix),
286 |         eval_collate_fn=collate_fn,
287 |         batch_size=per_device_batch_size * world_size, eval_batch_size=per_device_eval_batch_size * world_size,
288 |         rng=rng, num_workers=num_workers, shuffle_training=True
289 |     )
290 | 
291 |     data = Data(
292 |         n_classes=35, num_embeddings=700, train_size=len(train_data)
293 |     )
294 |     return train_loader, val_loader, test_loader, data
295 | 
296 | 
297 | def create_events_dvs_gesture_classification_dataset(
298 |         cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,
299 |         per_device_batch_size: int = 32,
300 |         per_device_eval_batch_size: int = 64,
301 |         world_size: int = 1,
302 |         num_workers: int = 0,
303 |         seed: int = 42,
304 |         slice_events: int = 0,
305 |         pad_unit: int = 2 ** 19,
306 |         # Augmentation parameters
307 |         time_jitter: float = 100,
308 |         spatial_jitter: float = 1.0,
309 |         noise: int = 100,
310 |         drop_event: float = 0.1,
311 |         time_skew: float = 1.1,
312 |         cut_mix: float = 0.5,
313 |         downsampling: int = 1,
314 |         max_roll: int = 4,
315 |         max_angle: float = 10,
316 |         max_scale: float = 1.5,
317 |         max_drop_chunk: float = 0.1,
318 |         validate_on_test: bool = False,
319 |         **kwargs
320 | ) -> Tuple[DataLoader, DataLoader, DataLoader, Data]:
321 |     """
322 |     Creates a view of the DVS Gesture dataset
323 | 
324 |     :param cache_dir: (str) Where to store the dataset
325 |     :param per_device_batch_size: (int) Batch size per device; evaluation uses per_device_eval_batch_size.
326 |     :param seed: (int) Seed for shuffling data.
327 |     :param slice_events: (int) Number of events per slice.
328 |     :param pad_unit: (int) Padding unit for the tokens. See collate function for more details.
329 |     :param time_jitter: (float) Standard deviation of the time jitter.
330 |     :param spatial_jitter: (float) Standard deviation of the spatial jitter.
331 |     :param noise: (int) Number of noise events to add.
332 |     :param drop_event: (float) Probability of dropping an event.
333 |     :param time_skew: (float) Time skew factor.
334 |     :param cut_mix: (float) Probability of applying cut mix augmentation.
335 |     :param downsampling: (int) Downsampling factor.
336 |     :param max_roll: (int) Maximum number of pixels to roll the events.
337 |     :param max_angle: (float) Maximum angle to rotate the events.
338 |     :param max_scale: (float) Maximum scale factor to scale the events.
339 |     :param max_drop_chunk: (float) Maximum fraction of events to drop in a single chunk.
340 |     :param validate_on_test: (bool) If True, use the test set for validation.
341 |                      Else use a random validation split from the training set.
342 | 
343 |     :return: train_loader, val_loader, test_loader, data
344 |     """
345 |     print("[*] Generating DVS Gesture Classification Dataset")
346 | 
347 |     assert time_skew > 1, "time_skew must be greater than 1"
348 | 
349 |     if seed is not None:
350 |         rng = torch.Generator()
351 |         rng.manual_seed(seed)
352 |     else:
353 |         rng = None
354 | 
355 |     orig_sensor_size = (128, 128, 2)
356 |     new_sensor_size = (128 // downsampling, 128 // downsampling, 2)
357 |     train_transforms = [
358 |         # Event transformations
359 |         DropEventChunk(p=0.3, max_drop_size=max_drop_chunk),
360 |         tonic.transforms.DropEvent(p=drop_event),
361 |         tonic.transforms.UniformNoise(sensor_size=new_sensor_size, n=(0, noise)),
362 |         # Time transformations
363 |         tonic.transforms.TimeSkew(coefficient=(1 / time_skew, time_skew), offset=0),
364 |         tonic.transforms.TimeJitter(std=time_jitter, clip_negative=False, sort_timestamps=True),
365 |         # Spatial transformations
366 |         tonic.transforms.SpatialJitter(sensor_size=orig_sensor_size, var_x=spatial_jitter, var_y=spatial_jitter, clip_outliers=True),
367 |         tonic.transforms.Downsample(sensor_size=orig_sensor_size, target_size=new_sensor_size[:2]) if downsampling > 1 else Identity(),
368 |         # Geometric transformations
369 |         Roll(sensor_size=new_sensor_size, p=0.3, max_roll=max_roll),
370 |         Rotate(sensor_size=new_sensor_size, p=0.3, max_angle=max_angle),
371 |         Scale(sensor_size=new_sensor_size, p=0.3, max_scale=max_scale),
372 |     ]
373 | 
374 |     train_transforms = tonic.transforms.Compose(train_transforms)
375 |     test_transforms = tonic.transforms.Compose([
376 |         tonic.transforms.Downsample(sensor_size=orig_sensor_size, target_size=new_sensor_size[:2]) if downsampling > 1 else Identity(),
377 |     ])
378 |     target_transforms = OneHotLabels(num_classes=11)
379 | 
380 |     TrainData = partial(tonic.datasets.DVSGesture, save_to=cache_dir, train=True)
381 |     TestData = partial(tonic.datasets.DVSGesture, save_to=cache_dir, train=False)
382 | 
383 |     # create validation set
384 |     if validate_on_test:
385 |         print("[*] WARNING: Using test set for validation")
386 |         val_data = TestData(transform=test_transforms, target_transform=target_transforms)
387 |     else:
388 |         # create train validation split
389 |         val_data = TrainData(transform=test_transforms, target_transform=target_transforms)
390 |         val_length = int(0.2 * len(val_data))
391 |         indices = torch.randperm(len(val_data), generator=rng)
392 |         val_data = torch.utils.data.Subset(val_data, indices[-val_length:])
393 | 
394 |     # if slice event count is given, train on slices of the training data
395 |     if slice_events > 0:
396 |         slicer = tonic.slicers.SliceByEventCount(event_count=slice_events, overlap=slice_events // 2, include_incomplete=True)
397 |         train_subset = torch.utils.data.Subset(TrainData(), indices[:-val_length]) if not validate_on_test else TrainData()
398 |         train_data = tonic.sliced_dataset.SlicedDataset(
399 |             dataset=train_subset,
400 |             slicer=slicer,
401 |             transform=train_transforms,
402 |             target_transform=target_transforms,
403 |             metadata_path=None
404 |         )
405 |     else:
406 |         train_data = torch.utils.data.Subset(
407 |             TrainData(transform=train_transforms, target_transform=target_transforms),
408 |             indices[:-val_length]
409 |         ) if not validate_on_test else TrainData(transform=train_transforms, target_transform=target_transforms)
410 | 
411 |     # Always evaluate on the full sequences
412 |     test_data = TestData(transform=test_transforms, target_transform=target_transforms)
413 | 
414 |     # define collate functions
415 |     train_collate_fn = partial(
416 |         event_stream_collate_fn,
417 |         resolution=new_sensor_size[:2],
418 |         pad_unit=slice_events if (slice_events != 0 and slice_events < pad_unit) else pad_unit,
419 |         cut_mix=cut_mix
420 |     )
421 |     eval_collate_fn = partial(
422 |         event_stream_collate_fn,
423 |         resolution=new_sensor_size[:2],
424 |         pad_unit=pad_unit,
425 |     )
426 |     train_loader, val_loader, test_loader = event_stream_dataloader(
427 |         train_data, val_data, test_data,
428 |         train_collate_fn=train_collate_fn,
429 |         eval_collate_fn=eval_collate_fn,
430 |         batch_size=per_device_batch_size * world_size, eval_batch_size=per_device_eval_batch_size * world_size,
431 |         rng=rng, num_workers=num_workers, shuffle_training=True
432 |     )
433 | 
434 |     data = Data(
435 |         n_classes=11, num_embeddings=np.prod(new_sensor_size), train_size=len(train_data)
436 |     )
437 |     return train_loader, val_loader, test_loader, data
438 | 
439 | 
440 | Datasets = {
441 |     "shd-classification": create_events_shd_classification_dataset,
442 |     "ssc-classification": create_events_ssc_classification_dataset,
443 |     "dvs-gesture-classification": create_events_dvs_gesture_classification_dataset,
444 | }
445 | 

--------------------------------------------------------------------------------
/event_ssm/layers.py:
--------------------------------------------------------------------------------
1 | from flax import linen as nn
2 | import jax
3 | from functools import partial
4 | 
5 | 
6 | class EventPooling(nn.Module):
7 |     """
8 |     Subsampling layer for event sequences.
9 |     """
10 |     stride: int = 1
11 |     mode: str = "last"
12 |     eps: float = 1e-6
13 | 
14 |     def __call__(self, x, integration_timesteps):
15 |         """
16 |         Compute the pooled (L/stride)xH output given an LxH input.
17 |         :param x: input sequence (L, d_model)
18 |         :param integration_timesteps: the integration timesteps for the SSM
19 |         :return: output sequence (L/stride, d_model)
20 |         """
21 |         if self.stride == 1:
22 |             raise ValueError("Stride 1 not supported for pooling")
23 | 
24 |         else:
25 |             remaining_timesteps = (len(integration_timesteps) // self.stride) * self.stride
26 |             new_integration_timesteps = integration_timesteps[:remaining_timesteps].reshape(-1, self.stride).sum(axis=1)
27 |             x = x[:remaining_timesteps]
28 |             d_model = x.shape[-1]
29 | 
30 |             if self.mode == 'last':
31 |                 x = x[::self.stride]
32 |                 return x, new_integration_timesteps
33 |             elif self.mode == 'avgpool':
34 |                 x = x.reshape(-1, self.stride, d_model).mean(axis=1)
35 |                 return x, new_integration_timesteps
36 |             elif self.mode == 'timepool':
37 |                 weight = integration_timesteps[:remaining_timesteps, None] + self.eps
38 |                 x = (x * weight).reshape(-1, self.stride, d_model).sum(axis=1)
39 |                 x = x / weight.reshape(-1, self.stride, 1).sum(axis=1)
40 |                 return x, new_integration_timesteps
41 |             else:
42 |                 raise NotImplementedError("Pooling mode: {} not implemented".format(self.mode))
43 | 
44 | 
45 | class SequenceStage(nn.Module):
46 |     """
47 |     Defines a block of EventSSM layers with the same hidden size and event-resolution
48 | 
49 |     :param ssm: the SSM to be used (i.e. S5 ssm)
50 |     :param d_model_in: this is the feature size of the layer inputs
51 |                        we usually refer to this size as H
52 |     :param d_model_out: this is the feature size of the layer outputs
53 |     :param d_ssm: the size of the state space model
54 |     :param ssm_block_size: the block size of the state space model
55 |     :param layers_per_stage: the number of S5 layers to stack
56 |     :param dropout: dropout rate
57 |     :param prenorm: whether to use layernorm before the module or after it
58 |     :param batchnorm: If True, use batchnorm instead of layernorm
59 |     :param bn_momentum: momentum for batchnorm
60 |     :param step_rescale: rescale the integration timesteps by this factor
61 |     :param pooling_stride: stride for pooling
62 |     :param pooling_mode: pooling mode (last, avgpool, timepool)
63 |     :param state_expansion_factor: factor to expand the state space model
64 |     """
65 |     ssm: nn.Module
66 |     discretization: str
67 |     d_model_in: int
68 |     d_model_out: int
69 |     d_ssm: int
70 |     ssm_block_size: int
71 |     layers_per_stage: int
72 |     dropout: float = 0.0
73 |     prenorm: bool = False
74 |     batchnorm: bool = False
75 |     bn_momentum: float = 0.9
76 |     step_rescale: float = 1.0
77 |     pooling_stride: int = 1
78 |     pooling_mode: str = "last"
79 |     state_expansion_factor: int = 1
80 | 
81 |     @nn.compact
82 |     def __call__(self, x, integration_timesteps, train: bool):
83 |         """
84 |         Compute the LxH output of the stacked encoder given an Lxd_input input sequence.
85 | 86 | :param x: input sequence (L, d_input) 87 | :param integration_timesteps: the integration timesteps for the SSM 88 | :param train: If True, applies dropout and batch norm from batch statistics 89 | :return: output sequence (L, d_model), integration_timesteps 90 | """ 91 | EventSSMLayer = partial( 92 | SequenceLayer, 93 | ssm=self.ssm, 94 | discretization=self.discretization, 95 | dropout=self.dropout, 96 | d_ssm=self.d_ssm, 97 | block_size=self.ssm_block_size, 98 | prenorm=self.prenorm, 99 | batchnorm=self.batchnorm, 100 | bn_momentum=self.bn_momentum, 101 | step_rescale=self.step_rescale, 102 | ) 103 | 104 | # first layer with pooling 105 | x, integration_timesteps = EventSSMLayer( 106 | d_model_in=self.d_model_in, 107 | d_model_out=self.d_model_out, 108 | pooling_stride=self.pooling_stride, 109 | pooling_mode=self.pooling_mode 110 | )(x, integration_timesteps, train=train) 111 | 112 | # further layers without pooling 113 | for l in range(self.layers_per_stage - 1): 114 | x, integration_timesteps = EventSSMLayer( 115 | d_model_in=self.d_model_out, 116 | d_model_out=self.d_model_out, 117 | pooling_stride=1 118 | )(x, integration_timesteps, train=train) 119 | 120 | return x, integration_timesteps 121 | 122 | 123 | class SequenceLayer(nn.Module): 124 | """ 125 | Defines a single event-ssm layer, with S5 SSM, nonlinearity, 126 | dropout, batch/layer norm, etc. 127 | 128 | :param ssm: the SSM to be used (i.e. S5 ssm) 129 | :param discretization: the discretization method to use (zoh, dirac, async) 130 | :param dropout: dropout rate 131 | :param d_model_in: the input feature size 132 | :param d_model_out: the output feature size 133 | :param d_ssm: the size of the state space model 134 | :param block_size: the block size of the state space model 135 | :param prenorm: whether to use layernorm before the module or after it 136 | :param batchnorm: If True, use batchnorm instead of layernorm 137 | :param bn_momentum: momentum for batchnorm 138 | :param step_rescale: rescale the integration timesteps by this factor 139 | :param pooling_stride: stride for pooling 140 | :param pooling_mode: pooling mode (last, avgpool, timepool) 141 | """ 142 | ssm: nn.Module 143 | discretization: str 144 | dropout: float 145 | d_model_in: int 146 | d_model_out: int 147 | d_ssm: int 148 | block_size: int 149 | prenorm: bool = False 150 | batchnorm: bool = False 151 | bn_momentum: float = 0.90 152 | step_rescale: float = 1.0 153 | pooling_stride: int = 1 154 | pooling_mode: str = "last" 155 | 156 | @nn.compact 157 | def __call__(self, x, integration_timesteps, train: bool): 158 | """ 159 | Compute a layer step 160 | 161 | :param x: input sequence (L, d_model_in) 162 | :param integration_timesteps: the integration timesteps for the SSM 163 | :param train: If True, applies dropout and batch norm from batch statistics 164 | :return: output sequence (L, d_model_out), integration_timesteps 165 | """ 166 | skip = x 167 | 168 | if self.prenorm: 169 | norm = nn.BatchNorm(momentum=self.bn_momentum, axis_name='batch') if self.batchnorm else nn.LayerNorm() 170 | x = norm(x, use_running_average=not train) if self.batchnorm else norm(x) 171 | 172 | # apply state space model 173 | x = self.ssm( 174 | H_in=self.d_model_in, H_out=self.d_model_out, P=self.d_ssm, block_size=self.block_size, 175 | step_rescale=self.step_rescale, discretization=self.discretization, 176 | stride=self.pooling_stride, pooling_mode=self.pooling_mode 177 | )(x, integration_timesteps) 178 | 179 | # non-linear activation function 180 | x1 = 
nn.Dropout(self.dropout, broadcast_dims=[0], deterministic=not train)(nn.gelu(x))
181 |         x1 = nn.Dense(self.d_model_out)(x1)
182 |         x = x * nn.sigmoid(x1)
183 |         x = nn.Dropout(self.dropout, broadcast_dims=[0], deterministic=not train)(x)
184 | 
185 |         if self.pooling_stride > 1:
186 |             pool = EventPooling(stride=self.pooling_stride, mode=self.pooling_mode)
187 |             skip, integration_timesteps = pool(skip, integration_timesteps)
188 | 
189 |         if self.d_model_in != self.d_model_out:
190 |             skip = nn.Dense(self.d_model_out)(skip)
191 | 
192 |         x = skip + x
193 | 
194 |         if not self.prenorm:
195 |             norm = nn.BatchNorm(momentum=self.bn_momentum, axis_name='batch') if self.batchnorm else nn.LayerNorm()
196 |             x = norm(x, use_running_average=not train) if self.batchnorm else norm(x)
197 | 
198 |         return x, integration_timesteps
199 | 

--------------------------------------------------------------------------------
/event_ssm/seq_model.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as np
3 | from flax import linen as nn
4 | from .layers import SequenceStage
5 | 
6 | 
7 | class StackedEncoderModel(nn.Module):
8 |     """
9 |     Defines a stack of S5 layers to be used as an encoder.
10 | 
11 |     :param ssm: the SSM to be used (i.e. S5 ssm)
12 |     :param discretization: the discretization to be used for the SSM
13 |     :param d_model: the feature size of the layer inputs and outputs. We usually refer to this size as H
14 |     :param d_ssm: the size of the state space model. We usually refer to this size as P
15 |     :param ssm_block_size: the block size of the state space model
16 |     :param num_stages: the number of SequenceStage blocks to stack
17 |     :param num_layers_per_stage: the number of EventSSM layers per stage
18 |     :param num_embeddings: the number of embeddings to use
19 |     :param dropout: dropout rate
20 |     :param prenorm: whether to use layernorm before the module or after it
21 |     :param batchnorm: If True, use batchnorm instead of layernorm
22 |     :param bn_momentum: momentum for batchnorm
23 |     :param step_rescale: rescale the integration timesteps by this factor
24 |     :param pooling_stride: stride for subsampling
25 |     :param pooling_every_n_layers: pool every n layers
26 |     :param pooling_mode: pooling mode (last, avgpool, timepool)
27 |     :param state_expansion_factor: factor to expand the state space model
28 |     """
29 |     ssm: nn.Module
30 |     discretization: str
31 |     d_model: int
32 |     d_ssm: int
33 |     ssm_block_size: int
34 |     num_stages: int
35 |     num_layers_per_stage: int
36 |     num_embeddings: int = 0
37 |     dropout: float = 0.0
38 |     prenorm: bool = False
39 |     batchnorm: bool = False
40 |     bn_momentum: float = 0.9
41 |     step_rescale: float = 1.0
42 |     pooling_stride: int = 1
43 |     pooling_every_n_layers: int = 1
44 |     pooling_mode: str = "last"
45 |     state_expansion_factor: int = 1
46 | 
47 |     def setup(self):
48 |         """
49 |         Initializes a linear encoder and the stack of EventSSM layers.
50 | """ 51 | assert self.num_embeddings > 0 52 | self.encoder = nn.Embed(num_embeddings=self.num_embeddings, features=self.d_model) 53 | 54 | # generate strides for the model 55 | stages = [] 56 | d_model_in = self.d_model 57 | d_model_out = self.d_model 58 | d_ssm = self.d_ssm 59 | total_downsampling = 1 60 | for stage in range(self.num_stages): 61 | # pool from the first layer but don't expand the state dim for the first layer 62 | total_downsampling *= self.pooling_stride 63 | 64 | stages.append( 65 | SequenceStage( 66 | ssm=self.ssm, 67 | discretization=self.discretization, 68 | d_model_in=d_model_in, 69 | d_model_out=d_model_out, 70 | d_ssm=d_ssm, 71 | ssm_block_size=self.ssm_block_size, 72 | layers_per_stage=self.num_layers_per_stage, 73 | dropout=self.dropout, 74 | prenorm=self.prenorm, 75 | batchnorm=self.batchnorm, 76 | bn_momentum=self.bn_momentum, 77 | step_rescale=self.step_rescale, 78 | pooling_stride=self.pooling_stride, 79 | pooling_mode=self.pooling_mode 80 | ) 81 | ) 82 | 83 | d_ssm = self.state_expansion_factor * d_ssm 84 | d_model_out = self.state_expansion_factor * d_model_in 85 | 86 | if stage > 0: 87 | d_model_in = self.state_expansion_factor * d_model_in 88 | 89 | self.stages = stages 90 | self.total_downsampling = total_downsampling 91 | 92 | def __call__(self, x, integration_timesteps, train: bool): 93 | """ 94 | Compute the LxH output of the stacked encoder given an Lxd_input 95 | input sequence. 96 | :param x: input sequence (L, d_input) 97 | :param integration_timesteps: the integration timesteps for the SSM 98 | :param train: If True, applies dropout and batch norm from batch statistics 99 | :return: output sequence (L, d_model), integration timesteps 100 | """ 101 | x = self.encoder(x) 102 | for i, stage in enumerate(self.stages): 103 | # apply layer SSM 104 | x, integration_timesteps = stage(x, integration_timesteps, train=train) 105 | return x, integration_timesteps 106 | 107 | 108 | def masked_meanpool(x, lengths): 109 | """ 110 | Helper function to perform mean pooling across the sequence length 111 | when sequences have variable lengths. We only want to pool across 112 | the prepadded sequence length. 113 | 114 | :param x: input sequence (L, d_model) 115 | :param lengths: the original length of the sequence before padding 116 | :return: mean pooled output sequence (d_model) 117 | """ 118 | L = x.shape[0] 119 | mask = np.arange(L) < lengths 120 | return np.sum(mask[..., None]*x, axis=0)/lengths 121 | 122 | 123 | def timepool(x, integration_timesteps): 124 | """ 125 | Helper function to perform weighted mean across the sequence length. 126 | Means are weighted with the integration time steps 127 | 128 | :param x: input sequence (L, d_model) 129 | :param integration_timesteps: the integration timesteps for the SSM 130 | :return: time pooled output sequence (d_model) 131 | """ 132 | T = np.sum(integration_timesteps, axis=0) 133 | integral = np.sum(x * integration_timesteps[..., None], axis=0) 134 | return integral / T 135 | 136 | 137 | def masked_timepool(x, lengths, integration_timesteps, eps=1e-6): 138 | """ 139 | Helper function to perform weighted mean across the sequence length 140 | when sequences have variable lengths. We only want to pool across 141 | the prepadded sequence length. 
Means are weighted with the integration time steps 142 | 143 | :param x: input sequence (L, d_model) 144 | :param lengths: the original length of the sequence before padding 145 | :param integration_timesteps: the integration timesteps for the SSM 146 | :param eps: small value to avoid division by zero 147 | :return: time pooled output sequence (d_model) 148 | """ 149 | L = x.shape[0] 150 | mask = np.arange(L) < lengths 151 | T = np.sum(integration_timesteps) 152 | 153 | # integrate with time weighting 154 | weight = integration_timesteps[..., None] + eps 155 | integral = np.sum(mask[..., None] * x * weight, axis=0) 156 | return integral / T 157 | 158 | 159 | # Here we call vmap to parallelize across a batch of input sequences 160 | batch_masked_meanpool = jax.vmap(masked_meanpool) 161 | 162 | 163 | class ClassificationModel(nn.Module): 164 | """ 165 | EventSSM classificaton sequence model. This consists of the stacked encoder 166 | (which consists of a linear encoder and stack of S5 layers), mean pooling 167 | across the sequence length, a linear decoder, and a softmax operation. 168 | 169 | :param ssm: the SSM to be used (i.e. S5 ssm) 170 | :param discretization: the discretization to be used for the SSM (zoh, dirac, async) 171 | :param num_classes: the number of classes for the classification task 172 | :param d_model: the feature size of the layer inputs and outputs. We usually refer to this size as H 173 | :param d_ssm: the size of the state space model. We usually refer to this size as P 174 | :param ssm_block_size: the block size of the state space model 175 | :param num_stages: the number of S5 layers to stack 176 | :param num_layers_per_stage: the number of EventSSM layers to stack 177 | :param num_embeddings: the number of embeddings to use 178 | :param dropout: dropout rate 179 | :param classification_mode: the classification mode (pool, timepool, last) 180 | :param prenorm: whether to use layernorm before the module or after it 181 | :param batchnorm: If True, use batchnorm instead of layernorm 182 | :param bn_momentum: momentum for batchnorm 183 | :param step_rescale: rescale the integration timesteps by this factor 184 | :param pooling_stride: stride for subsampling 185 | :param pooling_every_n_layers: pool every n layers 186 | :param pooling_mode: pooling mode (last, avgpool, timepool) 187 | :param state_expansion_factor: factor to expand the state space model 188 | """ 189 | ssm: nn.Module 190 | discretization: str 191 | num_classes: int 192 | d_model: int 193 | d_ssm: int 194 | ssm_block_size: int 195 | num_stages: int 196 | num_layers_per_stage: int 197 | num_embeddings: int = 0 198 | dropout: float = 0.2 199 | classification_mode: str = "pool" 200 | prenorm: bool = False 201 | batchnorm: bool = False 202 | bn_momentum: float = 0.9 203 | step_rescale: float = 1.0 204 | pooling_stride: int = 1 205 | pooling_every_n_layers: int = 1 206 | pooling_mode: str = "last" 207 | state_expansion_factor: int = 1 208 | 209 | def setup(self): 210 | """ 211 | Initializes the stacked EventSSM encoder and a linear decoder. 
212 | """ 213 | self.encoder = StackedEncoderModel( 214 | ssm=self.ssm, 215 | discretization=self.discretization, 216 | d_model=self.d_model, 217 | d_ssm=self.d_ssm, 218 | ssm_block_size=self.ssm_block_size, 219 | num_stages=self.num_stages, 220 | num_layers_per_stage=self.num_layers_per_stage, 221 | num_embeddings=self.num_embeddings, 222 | dropout=self.dropout, 223 | prenorm=self.prenorm, 224 | batchnorm=self.batchnorm, 225 | bn_momentum=self.bn_momentum, 226 | step_rescale=self.step_rescale, 227 | pooling_stride=self.pooling_stride, 228 | pooling_every_n_layers=self.pooling_every_n_layers, 229 | pooling_mode=self.pooling_mode, 230 | state_expansion_factor=self.state_expansion_factor 231 | ) 232 | self.decoder = nn.Dense(self.num_classes) 233 | 234 | def __call__(self, x, integration_timesteps, length, train=True): 235 | """ 236 | Compute the size num_classes log softmax output given a 237 | Lxd_input input sequence. 238 | 239 | :param x: input sequence (L, d_input) 240 | :param integration_timesteps: the integration timesteps for the SSM 241 | :param length: the original length of the sequence before padding 242 | :param train: If True, applies dropout and batch norm from batch statistics 243 | 244 | :return: output (num_classes) 245 | """ 246 | # if the sequence is downsampled we need to adjust the length 247 | length = length // self.encoder.total_downsampling 248 | 249 | # run encoder backbone 250 | x, integration_timesteps = self.encoder(x, integration_timesteps, train=train) 251 | 252 | # apply classification head 253 | if self.classification_mode in ["pool"]: 254 | # Perform mean pooling across time 255 | x = masked_meanpool(x, length) 256 | 257 | elif self.classification_mode in ["timepool"]: 258 | # Perform mean pooling across time weighted by integration time steps 259 | x = masked_timepool(x, length, integration_timesteps) 260 | 261 | elif self.classification_mode in ["last"]: 262 | # Just take the last state 263 | x = x[-1] 264 | else: 265 | raise NotImplementedError("Mode must be in ['pool', 'last]") 266 | 267 | x = self.decoder(x) 268 | return x 269 | 270 | 271 | # Here we call vmap to parallelize across a batch of input sequences 272 | BatchClassificationModel = nn.vmap( 273 | ClassificationModel, 274 | in_axes=(0, 0, 0, None), 275 | out_axes=0, 276 | variable_axes={"params": None, "dropout": None, 'batch_stats': None, "cache": 0, "prime": None}, 277 | split_rngs={"params": False, "dropout": True}, axis_name='batch') 278 | -------------------------------------------------------------------------------- /event_ssm/ssm.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | import jax.numpy as np 4 | from jax.scipy.linalg import block_diag 5 | 6 | from flax import linen as nn 7 | from jax.nn.initializers import lecun_normal, normal, glorot_normal 8 | 9 | from .ssm_init import init_CV, init_VinvB, init_log_steps, trunc_standard_normal, make_DPLR_HiPPO 10 | 11 | from .layers import EventPooling 12 | 13 | 14 | def discretize_zoh(Lambda, step_delta, time_delta): 15 | """ 16 | Discretize a diagonalized, continuous-time linear SSM 17 | using zero-order hold method. 18 | This is the default discretization method used by many SSM works including S5. 
19 | 20 | :param Lambda: diagonal state matrix (P,) 21 | :param step_delta: discretization step sizes (P,) 22 | :param time_delta: (float32) integration time step of the current event (scalar) 23 | :return: discretized Lambda_bar (complex64) (P,), input scaling gamma_bar (complex64) (P,) 24 | """ 25 | Identity = np.ones(Lambda.shape[0]) 26 | Delta = step_delta * time_delta 27 | Lambda_bar = np.exp(Lambda * Delta) 28 | gamma_bar = (1/Lambda * (Lambda_bar-Identity)) 29 | return Lambda_bar, gamma_bar 30 | 31 | 32 | def discretize_dirac(Lambda, step_delta, time_delta): 33 | """ 34 | Discretize a diagonalized, continuous-time linear SSM 35 | with Dirac delta input spikes. 36 | :param Lambda: diagonal state matrix (P,) 37 | :param step_delta: discretization step sizes (P,) 38 | :param time_delta: (float32) integration time step of the current event (scalar) 39 | :return: discretized Lambda_bar (complex64) (P,), gamma_bar = 1.0 (inputs enter as unit Dirac impulses) 40 | """ 41 | Delta = step_delta * time_delta 42 | Lambda_bar = np.exp(Lambda * Delta) 43 | gamma_bar = 1.0 44 | return Lambda_bar, gamma_bar 45 | 46 | 47 | def discretize_async(Lambda, step_delta, time_delta): 48 | """ 49 | Discretize a diagonalized, continuous-time linear SSM 50 | with Dirac delta input spikes and appropriate input normalization. 51 | 52 | :param Lambda: diagonal state matrix (P,) 53 | :param step_delta: discretization step sizes (P,) 54 | :param time_delta: (float32) integration time step of the current event (scalar) 55 | :return: discretized Lambda_bar (complex64) (P,), input scaling gamma_bar (complex64) (P,) 56 | """ 57 | Identity = np.ones(Lambda.shape[0]) 58 | Lambda_bar = np.exp(Lambda * step_delta * time_delta) 59 | gamma_bar = (1/Lambda * (np.exp(Lambda * step_delta)-Identity)) 60 | return Lambda_bar, gamma_bar 61 | 62 | 63 | # Parallel scan operations 64 | @jax.vmap 65 | def binary_operator(q_i, q_j): 66 | """ 67 | Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. 68 | 69 | :param q_i: tuple containing A_i and Bu_i at position i (P,), (P,) 70 | :param q_j: tuple containing A_j and Bu_j at position j (P,), (P,) 71 | :return: new element ( A_out, Bu_out ) 72 | """ 73 | A_i, b_i = q_i 74 | A_j, b_j = q_j 75 | return A_j * A_i, A_j * b_i + b_j 76 | 77 | 78 | def apply_ssm(Lambda_elements, Bu_elements, C_tilde, conj_sym, stride=1): 79 | """ 80 | Compute the LxH output of the discretized SSM given an LxH input. 
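The recurrence x_k = Lambda_k * x_{k-1} + Bu_k is evaluated in parallel with jax.lax.associative_scan; when stride > 1, the scanned states are subsampled before the projection with C_tilde, so the output has length L // stride.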
81 | 82 | :param Lambda_elements: (complex64) discretized state matrix (L, P) 83 | :param Bu_elements: (complex64) discretized inputs projected to state space (L, P) 84 | :param C_tilde: (complex64) output matrix (H, P) 85 | :param conj_sym: (bool) whether conjugate symmetry is enforced 86 | :return: ys: (float32) the SSM outputs (S5 layer preactivations) (L, H) 87 | """ 88 | remaining_timesteps = (Bu_elements.shape[0] // stride) * stride 89 | 90 | _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements)) 91 | 92 | xs = xs[:remaining_timesteps:stride] 93 | 94 | if conj_sym: 95 | return jax.vmap(lambda x: 2*(C_tilde @ x).real)(xs) 96 | else: 97 | return jax.vmap(lambda x: (C_tilde @ x).real)(xs) 98 | 99 | 100 | class S5SSM(nn.Module): 101 | H_in: int 102 | H_out: int 103 | P: int 104 | block_size: int 105 | C_init: str 106 | discretization: str 107 | dt_min: float 108 | dt_max: float 109 | conj_sym: bool = True 110 | clip_eigs: bool = False 111 | step_rescale: float = 1.0 112 | stride: int = 1 113 | pooling_mode: str = "last" 114 | 115 | """ 116 | Event-based S5 module 117 | 118 | :param H_in: int, SSM input dimension 119 | :param H_out: int, SSM output dimension 120 | :param P: int, SSM state dimension 121 | :param block_size: int, block size for block-diagonal state matrix 122 | :param C_init: str, initialization method for output matrix C 123 | :param discretization: str, discretization method for event-based SSM 124 | :param dt_min: float, minimum value of log timestep 125 | :param dt_max: float, maximum value of log timestep 126 | :param conj_sym: bool, whether to enforce conjugate symmetry in the state space operator 127 | :param clip_eigs: bool, whether to clip eigenvalues of the state space operator 128 | :param step_rescale: float, rescale factor for step size 129 | :param stride: int, stride for subsampling layer 130 | :param pooling_mode: str, pooling mode for subsampling layer 131 | """ 132 | 133 | def setup(self): 134 | """ 135 | Initializes parameters once and performs discretization each time the SSM is applied to a sequence 136 | """ 137 | 138 | # Initialize state matrix A using approximation to HiPPO-LegS matrix 139 | Lambda, _, B, V, B_orig = make_DPLR_HiPPO(self.block_size) 140 | 141 | blocks = self.P // self.block_size 142 | block_size = self.block_size // 2 if self.conj_sym else self.block_size 143 | local_P = self.P // 2 if self.conj_sym else self.P 144 | 145 | Lambda = Lambda[:block_size] 146 | V = V[:, :block_size] 147 | Vc = V.conj().T 148 | 149 | # If initializing state matrix A as block-diagonal, put HiPPO approximation 150 | # on each block 151 | Lambda = (Lambda * np.ones((blocks, block_size))).ravel() 152 | V = block_diag(*([V] * blocks)) 153 | Vinv = block_diag(*([Vc] * blocks)) 154 | 155 | state_str = f"SSM: {self.H_in} -> {self.P} -> {self.H_out}" 156 | if self.stride > 1: 157 | state_str += f" (stride {self.stride} with pooling mode {self.pooling_mode})" 158 | print(state_str) 159 | 160 | # Initialize diagonal state to state matrix Lambda (eigenvalues) 161 | self.Lambda_re = self.param("Lambda_re", lambda rng, shape: Lambda.real, (None,)) 162 | self.Lambda_im = self.param("Lambda_im", lambda rng, shape: Lambda.imag, (None,)) 163 | 164 | if self.clip_eigs: 165 | self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im 166 | else: 167 | self.Lambda = self.Lambda_re + 1j * self.Lambda_im 168 | 169 | # Initialize input to state (B) matrix 170 | B_init = lecun_normal() 171 | B_shape = (self.P, self.H_in) 172 | self.B = 
self.param("B", 173 | lambda rng, shape: init_VinvB(B_init, rng, shape, Vinv), 174 | B_shape) 175 | 176 | # Initialize state to output (C) matrix 177 | if self.C_init in ["trunc_standard_normal"]: 178 | C_init = trunc_standard_normal 179 | C_shape = (self.H_out, self.P, 2) 180 | elif self.C_init in ["lecun_normal"]: 181 | C_init = lecun_normal() 182 | C_shape = (self.H_out, self.P, 2) 183 | elif self.C_init in ["complex_normal"]: 184 | C_init = normal(stddev=0.5 ** 0.5) 185 | else: 186 | raise NotImplementedError( 187 | "C_init method {} not implemented".format(self.C_init)) 188 | 189 | if self.C_init in ["complex_normal"]: 190 | C = self.param("C", C_init, (self.H_out, local_P, 2)) 191 | self.C_tilde = C[..., 0] + 1j * C[..., 1] 192 | 193 | else: 194 | self.C = self.param("C", 195 | lambda rng, shape: init_CV(C_init, rng, shape, V), 196 | C_shape) 197 | 198 | self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1] 199 | 200 | # Initialize feedthrough (D) matrix 201 | if self.H_in == self.H_out: 202 | self.D = self.param("D", normal(stddev=1.0), (self.H_in,)) 203 | else: 204 | self.D = self.param("D", glorot_normal(), (self.H_out, self.H_in)) 205 | 206 | # Initialize learnable discretization timescale value 207 | self.log_step = self.param("log_step", 208 | init_log_steps, 209 | (local_P, self.dt_min, self.dt_max)) 210 | 211 | # pooling layer 212 | self.pool = EventPooling(stride=self.stride, mode=self.pooling_mode) 213 | 214 | # Discretize 215 | if self.discretization in ["zoh"]: 216 | self.discretize_fn = discretize_zoh 217 | elif self.discretization in ["dirac"]: 218 | self.discretize_fn = discretize_dirac 219 | elif self.discretization in ["async"]: 220 | self.discretize_fn = discretize_async 221 | else: 222 | raise NotImplementedError("Discretization method {} not implemented".format(self.discretization)) 223 | 224 | def __call__(self, input_sequence, integration_timesteps): 225 | """ 226 | Compute the LxH output of the S5 SSM given an LxH input sequence using a parallel scan. 227 | 228 | :param input_sequence: (float32) input sequence (L, H) 229 | :param integration_timesteps: (float32) integration timesteps (L,) 230 | :return: (float32) output sequence (L, H) 231 | """ 232 | 233 | # discretize on the fly 234 | B = self.B[..., 0] + 1j * self.B[..., 1] 235 | 236 | def discretize_and_project_inputs(u, _timestep): 237 | step = self.step_rescale * np.exp(self.log_step[:, 0]) 238 | Lambda_bar, gamma_bar = self.discretize_fn(self.Lambda, step, _timestep) 239 | Bu = gamma_bar * (B @ u) 240 | return Lambda_bar, Bu 241 | 242 | Lambda_bar_elements, Bu_bar_elements = jax.vmap(discretize_and_project_inputs)(input_sequence, integration_timesteps) 243 | 244 | ys = apply_ssm( 245 | Lambda_bar_elements, 246 | Bu_bar_elements, 247 | self.C_tilde, 248 | self.conj_sym, 249 | stride=self.stride 250 | ) 251 | 252 | if self.stride > 1: 253 | input_sequence, _ = self.pool(input_sequence, integration_timesteps) 254 | 255 | if self.H_in == self.H_out: 256 | Du = jax.vmap(lambda u: self.D * u)(input_sequence) 257 | else: 258 | Du = jax.vmap(lambda u: self.D @ u)(input_sequence) 259 | 260 | return ys + Du 261 | 262 | 263 | def init_S5SSM( 264 | C_init, 265 | dt_min, 266 | dt_max, 267 | conj_sym, 268 | clip_eigs, 269 | ): 270 | """ 271 | Convenience function that will be used to initialize the SSM. 272 | Same arguments as defined in S5SSM above. 
273 | """ 274 | return partial(S5SSM, 275 | C_init=C_init, 276 | dt_min=dt_min, 277 | dt_max=dt_max, 278 | conj_sym=conj_sym, 279 | clip_eigs=clip_eigs 280 | ) 281 | -------------------------------------------------------------------------------- /event_ssm/ssm_init.py: -------------------------------------------------------------------------------- 1 | from jax import random 2 | import jax.numpy as np 3 | from jax.nn.initializers import lecun_normal 4 | from jax.numpy.linalg import eigh 5 | 6 | 7 | def make_HiPPO(N): 8 | """ 9 | Create a HiPPO-LegS matrix. 10 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py 11 | 12 | :params N: int32, state size 13 | :returns: N x N HiPPO LegS matrix 14 | """ 15 | P = np.sqrt(1 + 2 * np.arange(N)) 16 | A = P[:, np.newaxis] * P[np.newaxis, :] 17 | A = np.tril(A) - np.diag(np.arange(N)) 18 | return -A 19 | 20 | 21 | def make_NPLR_HiPPO(N): 22 | """ 23 | Makes components needed for NPLR representation of HiPPO-LegS 24 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py 25 | 26 | :params N: int32, state size 27 | :returns: N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B 28 | """ 29 | # Make -HiPPO 30 | hippo = make_HiPPO(N) 31 | 32 | # Add in a rank 1 term. Makes it Normal. 33 | P = np.sqrt(np.arange(N) + 0.5) 34 | 35 | # HiPPO also specifies the B matrix 36 | B = np.sqrt(2 * np.arange(N) + 1.0) 37 | return hippo, P, B 38 | 39 | 40 | def make_DPLR_HiPPO(N): 41 | """ 42 | Makes components needed for DPLR representation of HiPPO-LegS 43 | From https://github.com/srush/annotated-s4/blob/main/s4/s4.py 44 | Note, we will only use the diagonal part 45 | 46 | :params N: int32, state size 47 | :returns: eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, 48 | eigenvectors V, HiPPO B pre-conjugation 49 | """ 50 | A, P, B = make_NPLR_HiPPO(N) 51 | 52 | S = A + P[:, np.newaxis] * P[np.newaxis, :] 53 | 54 | S_diag = np.diagonal(S) 55 | Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) 56 | 57 | # Diagonalize S to V \Lambda V^* 58 | Lambda_imag, V = eigh(S * -1j) 59 | 60 | P = V.conj().T @ P 61 | B_orig = B 62 | B = V.conj().T @ B 63 | return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig 64 | 65 | 66 | def log_step_initializer(dt_min=0.001, dt_max=0.1): 67 | """ 68 | Initialize the learnable timescale Delta by sampling 69 | uniformly between dt_min and dt_max. 70 | 71 | :params dt_min: float32, minimum value of log timestep 72 | :params dt_max: float32, maximum value of log timestep 73 | :returns: init function 74 | """ 75 | def init(key, shape): 76 | return random.uniform(key, shape) * ( 77 | np.log(dt_max) - np.log(dt_min) 78 | ) + np.log(dt_min) 79 | 80 | return init 81 | 82 | 83 | def init_log_steps(key, input): 84 | """ 85 | Initialize an array of learnable timescale parameters 86 | 87 | :params key: jax random 88 | :params input: tuple containing the array shape H and 89 | dt_min and dt_max 90 | :returns: initialized array of timescales (float32): (H,) 91 | """ 92 | H, dt_min, dt_max = input 93 | log_steps = [] 94 | for i in range(H): 95 | key, skey = random.split(key) 96 | log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) 97 | log_steps.append(log_step) 98 | 99 | return np.array(log_steps) 100 | 101 | 102 | def init_VinvB(init_fun, rng, shape, Vinv): 103 | """ 104 | Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. 105 | Note we will parameterize this with two different matrices for complex numbers. 
106 | 107 | :params init_fun: function, the initialization function to use, e.g. lecun_normal() 108 | :params rng: jax random key to be used with init function. 109 | :params shape: tuple, desired shape (P,H) 110 | :params Vinv: complex64, the inverse eigenvectors used for initialization 111 | :returns: B_tilde (complex64) of shape (P,H,2) 112 | """ 113 | B = init_fun(rng, shape) 114 | VinvB = Vinv @ B 115 | VinvB_real = VinvB.real 116 | VinvB_imag = VinvB.imag 117 | return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) 118 | 119 | 120 | def trunc_standard_normal(key, shape): 121 | """ 122 | Sample C with a truncated normal distribution with standard deviation 1. 123 | 124 | :params key: jax random key 125 | :params shape: tuple, desired shape (H,P, _) 126 | :returns: sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) 127 | """ 128 | H, P, _ = shape 129 | Cs = [] 130 | for i in range(H): 131 | key, skey = random.split(key) 132 | C = lecun_normal()(skey, shape=(1, P, 2)) 133 | Cs.append(C) 134 | return np.array(Cs)[:, 0] 135 | 136 | 137 | def init_CV(init_fun, rng, shape, V): 138 | """ 139 | Initialize C_tilde=CV. First sample C. Then compute CV. 140 | Note we will parameterize this with two different matrices for complex numbers. 141 | 142 | :params init_fun: function, the initialization function to use, e.g. lecun_normal() 143 | :params rng: jax random key to be used with init function. 144 | :params shape: tuple, desired shape (H,P) 145 | :params V: complex64, the eigenvectors used for initialization 146 | :returns: C_tilde (complex64) of shape (H,P,2) 147 | """ 148 | C_ = init_fun(rng, shape) 149 | C = C_[..., 0] + 1j * C_[..., 1] 150 | CV = C @ V 151 | CV_real = CV.real 152 | CV_imag = CV.imag 153 | return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) 154 | -------------------------------------------------------------------------------- /event_ssm/train_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | import jax.numpy as jnp 4 | from jaxtyping import Array 5 | from typing import Any, Dict 6 | import random 7 | from flax.training import train_state 8 | import optax 9 | from functools import partial 10 | 11 | 12 | class TrainState(train_state.TrainState): 13 | key: Array 14 | model_state: Dict 15 | 16 | 17 | def training_step( 18 | train_state: TrainState, 19 | batch: Array, 20 | dropout_key: Array, 21 | distributed: bool = False 22 | ): 23 | """ 24 | Conducts a single training step on a batch of data. 
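The loss is the mean softmax cross-entropy over the batch; gradients are computed with jax.value_and_grad, batch norm statistics are threaded through the mutable 'batch_stats' collection, and gradients and metrics are averaged across devices with jax.lax.pmean when distributed=True.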
25 | 26 | :param train_state: a Flax TrainState that carries the parameters, optimizer states etc 27 | :param batch: the data consisting of [data, target, integration_timesteps, lengths] :param dropout_key: the PRNG key for dropout 28 | :param distributed: If True, apply reduce operations like psum, pmean etc 29 | :return: train_state, metrics 30 | """ 31 | inputs, targets, integration_timesteps, lengths = batch 32 | 33 | def loss_fn(params): 34 | logits, updates = train_state.apply_fn( 35 | {'params': params, **train_state.model_state}, 36 | inputs, integration_timesteps, lengths, 37 | True, 38 | rngs={'dropout': dropout_key}, 39 | mutable=['batch_stats'] 40 | ) 41 | 42 | loss = optax.softmax_cross_entropy(logits, targets) 43 | loss = loss.mean() 44 | 45 | return loss, (logits, updates) 46 | 47 | (loss, (logits, batch_updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(train_state.params) 48 | 49 | preds = jnp.argmax(logits, axis=-1) 50 | targets = jnp.argmax(targets, axis=-1) 51 | accuracy = (preds == targets).mean() 52 | 53 | if distributed: 54 | grads = jax.lax.pmean(grads, axis_name='data') 55 | loss = jax.lax.pmean(loss, axis_name='data') 56 | accuracy = jax.lax.pmean(accuracy, axis_name='data') 57 | 58 | train_state = train_state.apply_gradients(grads=grads) 59 | train_state = train_state.replace(model_state=batch_updates) 60 | 61 | return train_state, {'loss': loss, 'accuracy': accuracy} 62 | 63 | 64 | def evaluation_step( 65 | train_state: TrainState, 66 | batch: Array, 67 | distributed: bool = False 68 | ): 69 | """ 70 | Conducts a single evaluation step on a batch of data. 71 | 72 | :param train_state: a Flax TrainState that carries the parameters, optimizer states etc 73 | :param batch: the data consisting of [data, target, integration_timesteps, lengths] 74 | :param distributed: If True, apply reduce operations like psum, pmean etc 75 | :return: train_state, metrics 76 | """ 77 | inputs, targets, integration_timesteps, lengths = batch 78 | logits = train_state.apply_fn( 79 | {'params': train_state.params, **train_state.model_state}, 80 | inputs, integration_timesteps, lengths, 81 | False, 82 | ) 83 | loss = optax.softmax_cross_entropy(logits, targets) 84 | loss = loss.mean() 85 | preds = jnp.argmax(logits, axis=-1) 86 | targets = jnp.argmax(targets, axis=-1) 87 | accuracy = (preds == targets).mean() 88 | 89 | if distributed: 90 | loss = jax.lax.pmean(loss, axis_name='data') 91 | accuracy = jax.lax.pmean(accuracy, axis_name='data') 92 | 93 | return train_state, {'loss': loss, 'accuracy': accuracy} 94 | 95 | 96 | def map_nested_fn(fn): 97 | """ 98 | Recursively apply `fn` to the key-value pairs of a nested dict / pytree. 99 | We use this for some of the optax definitions below. 
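For example (illustrative): >>> label_fn = map_nested_fn(lambda k, _: 'ssm' if k == 'B' else 'regular') >>> label_fn({'layer': {'B': 0, 'D': 0}}) {'layer': {'B': 'ssm', 'D': 'regular'}}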
100 | """ 101 | 102 | def map_fn(nested_dict): 103 | return { 104 | k: (map_fn(v) if hasattr(v, "keys") else fn(k, v)) 105 | for k, v in nested_dict.items() 106 | } 107 | 108 | return map_fn 109 | 110 | 111 | def map_nested_fn_with_keyword(keyword_1, keyword_2): 112 | '''labels all the leaves that are descendants of keyword_1 with keyword 1, 113 | else label the leaf with keyword_2''' 114 | 115 | def map_fn(nested_dict): 116 | output_dict = {} 117 | for k, v in nested_dict.items(): 118 | if isinstance(v, dict): 119 | if k == keyword_1: 120 | output_dict[k] = map_fn_2(v) 121 | else: 122 | output_dict[k] = map_fn(v) 123 | else: 124 | if k == keyword_1: 125 | output_dict[k] = keyword_1 126 | else: 127 | output_dict[k] = keyword_2 128 | return output_dict 129 | 130 | def map_fn_2(nested_dict): 131 | output_dict = {} 132 | for k, v in nested_dict.items(): 133 | if isinstance(v, dict): 134 | output_dict[k] = map_fn_2(v) 135 | else: 136 | output_dict[k] = keyword_1 137 | return output_dict 138 | 139 | return map_fn 140 | 141 | 142 | def seed_all(seed): 143 | random.seed(seed) 144 | np.random.seed(seed) 145 | 146 | 147 | def get_first_device(x): 148 | x = jax.tree_util.tree_map(lambda a: a[0], x) 149 | return jax.device_get(x) 150 | 151 | 152 | def print_model_size(params, name=''): 153 | fn_is_complex = lambda x: x.dtype in [np.complex64, np.complex128] 154 | param_sizes = map_nested_fn(lambda k, param: param.size * (2 if fn_is_complex(param) else 1))(params) 155 | total_params_size = sum(jax.tree_leaves(param_sizes)) 156 | print('[*] Model parameter count:', total_params_size) 157 | 158 | 159 | def get_learning_rate_fn(lr, total_steps, warmup_steps, schedule, **kwargs): 160 | if schedule == 'cosine': 161 | learning_rate_fn = optax.warmup_cosine_decay_schedule( 162 | init_value=0., 163 | peak_value=lr, 164 | warmup_steps=warmup_steps, 165 | decay_steps=total_steps 166 | ) 167 | elif schedule == 'constant': 168 | learning_rate_fn = optax.join_schedules([ 169 | optax.linear_schedule( 170 | init_value=0., 171 | end_value=lr, 172 | transition_steps=warmup_steps 173 | ), 174 | optax.constant_schedule(lr) 175 | ], [warmup_steps]) 176 | else: 177 | raise ValueError(f'Unknown schedule: {schedule}') 178 | 179 | return learning_rate_fn 180 | 181 | 182 | def get_optimizer(opt_config): 183 | 184 | ssm_lrs = ["B", "Lambda_re", "Lambda_im"] 185 | ssm_fn = map_nested_fn( 186 | lambda k, _: "ssm" 187 | if k in ssm_lrs 188 | else "regular" 189 | ) 190 | learning_rate_fn = partial( 191 | get_learning_rate_fn, 192 | total_steps=opt_config.total_steps, 193 | warmup_steps=opt_config.warmup_steps, 194 | schedule=opt_config.schedule 195 | ) 196 | 197 | def optimizer(learning_rate): 198 | tx = optax.multi_transform( 199 | { 200 | "ssm": optax.inject_hyperparams(partial( 201 | optax.adamw, 202 | b1=0.9, b2=0.999, 203 | weight_decay=opt_config.ssm_weight_decay 204 | ))(learning_rate=learning_rate_fn(lr=learning_rate)), 205 | "regular": optax.adamw( 206 | learning_rate=learning_rate_fn(lr=learning_rate * opt_config.lr_factor), 207 | b1=0.9, b2=0.999, 208 | weight_decay=opt_config.weight_decay), 209 | }, 210 | ssm_fn, 211 | ) 212 | if opt_config.get('accumulation_steps', False): 213 | print(f"[*] Using gradient accumulation with {opt_config.accumulation_steps} steps") 214 | tx = optax.MultiSteps(tx, every_k_schedule=opt_config.accumulation_steps) 215 | return tx 216 | 217 | return optimizer(opt_config.ssm_lr) 218 | 219 | 220 | def init_model_state(rng_key, model, inputs, steps, lengths, opt_config): 221 | """ 222 | 
Initialize the training state. 223 | 224 | :param rng_key: a PRNGKey 225 | :param model: the Flax model to train 226 | :param inputs: dummy input data 227 | :param steps: dummy integration timesteps 228 | :param lengths: dummy number of events 229 | :param opt_config: a dictionary containing the optimizer configuration 230 | :return: a TrainState object 231 | """ 232 | init_key, dropout_key = jax.random.split(rng_key) 233 | variables = model.init( 234 | {"params": init_key, 235 | "dropout": dropout_key}, 236 | inputs, steps, lengths, True 237 | ) 238 | params = variables.pop('params') 239 | model_state = variables 240 | print_model_size(params) 241 | 242 | tx = get_optimizer(opt_config) 243 | return TrainState.create( 244 | apply_fn=model.apply, 245 | params=params, 246 | tx=tx, 247 | key=dropout_key, 248 | model_state=model_state 249 | ) 250 | -------------------------------------------------------------------------------- /event_ssm/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import json 4 | import sys 5 | import wandb 6 | from collections import defaultdict, OrderedDict 7 | from omegaconf import OmegaConf as om 8 | from omegaconf import DictConfig 9 | import jax.numpy as jnp 10 | import jax.random 11 | from jaxtyping import Array 12 | from typing import Callable, Dict, Optional, Iterator, Any 13 | from flax.training.train_state import TrainState 14 | from flax.training import checkpoints 15 | from flax import jax_utils 16 | from functools import partial 17 | 18 | 19 | @partial(jax.jit, static_argnums=(1,)) 20 | def reshape_batch_per_device(x, num_devices): 21 | return jax.tree_util.tree_map(partial(reshape_array_per_device, num_devices=num_devices), x) 22 | 23 | 24 | def reshape_array_per_device(x, num_devices): 25 | batch_size_per_device, ragged = divmod(x.shape[0], num_devices) 26 | if ragged: 27 | msg = "batch size must be divisible by device count, got {} and {}." 28 | raise ValueError(msg.format(x.shape[0], num_devices)) 29 | return x.reshape((num_devices, batch_size_per_device, ) + (x.shape[1:])) 30 | 31 | 32 | class TrainerModule: 33 | """ 34 | Handles training and logging of models. Most of the boilerplate code is hidden from the user. 35 | """ 36 | def __init__( 37 | self, 38 | train_state: TrainState, 39 | training_step_fn: Callable, 40 | evaluation_step_fn: Callable, 41 | world_size: int, 42 | config: DictConfig, 43 | ): 44 | """ 45 | 46 | :param train_state: A TrainState object that contains the model parameters, optimizer states etc. 47 | :param training_step_fn: A function that takes the train_state and a batch of data and returns the updated train_state and metrics. 48 | :param evaluation_step_fn: A function that takes the train_state and a batch of data and returns the updated train_state and metrics. 49 | :param world_size: Number of devices to run the training on. 50 | :param config: The configuration of the training run. 
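With world_size > 1, the step functions are expected to be pmapped over a leading device axis (see run_training.py below), and incoming batches are reshaped accordingly with reshape_batch_per_device.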
51 | """ 52 | super().__init__() 53 | self.train_state = train_state 54 | self.train_step = training_step_fn 55 | self.eval_step = evaluation_step_fn 56 | 57 | self.world_size = world_size 58 | self.log_config = config.logging 59 | self.epoch_idx = 0 60 | self.num_epochs = config.training.num_epochs 61 | self.best_eval_metrics = {} 62 | 63 | # logger details 64 | self.log_dir = os.path.join(self.log_config.log_dir) 65 | print('[*] Logging to', self.log_dir) 66 | 67 | if not os.path.isdir(self.log_dir): 68 | os.makedirs(self.log_dir) 69 | if not os.path.isdir(os.path.join(self.log_dir, 'metrics')): 70 | os.makedirs(os.path.join(self.log_dir, 'metrics')) 71 | if not os.path.isdir(os.path.join(self.log_dir, 'checkpoints')): 72 | os.makedirs(os.path.join(self.log_dir, 'checkpoints')) 73 | 74 | num_parameters = int(sum( 75 | [arr.size for arr in jax.tree_flatten(self.train_state.params)[0] 76 | if isinstance(arr, Array)] 77 | ) / self.world_size) 78 | print("[*] Number of model parameters:", num_parameters) 79 | 80 | if self.log_config.wandb: 81 | wandb.init( 82 | # set the wandb project where this run will be logged 83 | dir=self.log_config.log_dir, 84 | project=self.log_config.project, 85 | entity=self.log_config.entity, 86 | config=om.to_container(config, resolve=True)) 87 | wandb.config.update({'SLURM_JOB_ID': os.getenv('SLURM_JOB_ID')}) 88 | 89 | # log number of parameters 90 | wandb.run.summary['Num parameters'] = num_parameters 91 | wandb.define_metric(self.log_config.summary_metric, summary='max') 92 | 93 | def train_model( 94 | self, 95 | train_loader: Iterator, 96 | val_loader: Iterator, 97 | dropout_key: Array, 98 | test_loader: Optional[Iterator] = None, 99 | ) -> Dict[str, Any]: 100 | """ 101 | Trains a model on a dataset. 102 | 103 | :param train_loader: Data loader of the training set. 104 | :param val_loader: Data loader of the validation set. 105 | :param dropout_key: Random key for dropout. 106 | :param test_loader: Data loader of the test set. 107 | :return: A dictionary of the best evaluation metrics. 
108 | """ 109 | 110 | # Prepare training loop 111 | self.on_training_start() 112 | 113 | for epoch_idx in range(1, self.num_epochs+1): 114 | self.epoch_idx = epoch_idx 115 | 116 | # run training step for this epoch 117 | train_metrics = self.train_epoch(train_loader, dropout_key) 118 | 119 | self.on_training_epoch_end(train_metrics) 120 | 121 | # Validation every N epochs 122 | eval_metrics = self.eval_model( 123 | val_loader, 124 | log_prefix='Performance/Validation', 125 | ) 126 | 127 | self.on_validation_epoch_end(eval_metrics) 128 | 129 | if self.log_config.wandb: 130 | from optax import MultiStepsState 131 | wandb_metrics = {'Performance/epoch': epoch_idx} 132 | wandb_metrics.update(train_metrics) 133 | wandb_metrics.update(eval_metrics) 134 | if isinstance(self.train_state.opt_state, MultiStepsState): 135 | lr = self.train_state.opt_state.inner_opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate'].item() 136 | else: 137 | lr = self.train_state.opt_state.inner_states['ssm'].inner_state.hyperparams['learning_rate'].item() 138 | wandb_metrics['learning rate'] = lr 139 | wandb.log(wandb_metrics) 140 | 141 | # Test best model if possible 142 | if test_loader is not None: 143 | self.load_model() 144 | test_metrics = self.eval_model( 145 | test_loader, 146 | log_prefix='Performance/Test', 147 | ) 148 | self.save_metrics('test', test_metrics) 149 | self.best_eval_metrics.update(test_metrics) 150 | 151 | if self.log_config.wandb: 152 | wandb.log(test_metrics) 153 | 154 | print('-' * 89) 155 | print('| End of Training |') 156 | print('| Test Metrics |', 157 | ' | '.join([f"{k.split('/')[1].replace('Test','')}: {v:5.2f}" for k, v in test_metrics.items() if 'Test' in k])) 158 | print('-' * 89) 159 | 160 | return self.best_eval_metrics 161 | 162 | def train_epoch(self, train_loader: Iterator, dropout_key) -> Dict[str, Any]: 163 | """ 164 | Trains the model on one epoch of the training set. 165 | 166 | :param train_loader: Data loader of the training set. 167 | :param dropout_key: Random key for dropout. 168 | :return: A dictionary of the training metrics. 
169 | """ 170 | 171 | # Train model for one epoch, and log avg loss and accuracy 172 | metrics = defaultdict(float) 173 | running_metrics = defaultdict(float) 174 | num_batches = 0 175 | num_train_batches = len(train_loader) 176 | start_time = time.time() 177 | epoch_start_time = start_time 178 | 179 | # set up intra epoch logging 180 | log_interval = self.log_config.interval 181 | 182 | for i, batch in enumerate(train_loader): 183 | num_batches += 1 184 | 185 | # skip batches with empty sequences which might randomly occur due to data augmentation 186 | _, _, _, lengths = batch 187 | if jnp.any(lengths == 0): 188 | continue 189 | 190 | if self.world_size > 1: 191 | step_key, dropout_key = jax.vmap(jax.random.split, in_axes=0, out_axes=1)(dropout_key) 192 | step_key = jax.vmap(jax.random.fold_in)(step_key, jnp.arange(self.world_size)) 193 | batch = reshape_batch_per_device(batch, self.world_size) 194 | else: 195 | step_key, dropout_key = jax.random.split(dropout_key) 196 | 197 | self.train_state, step_metrics = self.train_step(self.train_state, batch, step_key) 198 | 199 | # exit from training if loss is nan 200 | if jnp.isnan(step_metrics['loss']).any(): 201 | print("EXITING TRAINING DUE TO NAN LOSS") 202 | break 203 | 204 | # record metrics 205 | for key in step_metrics: 206 | metrics['Performance/Training ' + key] += step_metrics[key] 207 | running_metrics['Performance/Training ' + key] += step_metrics[key] 208 | 209 | # print metrics to terminal 210 | if (i + 1) % log_interval == 0: 211 | elapsed = time.time() - start_time 212 | start_time = time.time() 213 | print(f'| epoch {self.epoch_idx} | {i + 1}/{num_train_batches} batches | ms/batch {elapsed * 1000 / log_interval:5.2f} |', 214 | ' | '.join([f'{k}: {jnp.mean(v).item() / log_interval:5.2f}' for k, v in running_metrics.items()])) 215 | for key in step_metrics: 216 | running_metrics['Performance/Training ' + key] = 0 217 | 218 | metrics = {key: jnp.mean(metrics[key] / num_batches).item() for key in metrics} 219 | metrics['epoch_time'] = time.time() - epoch_start_time 220 | return metrics 221 | 222 | def eval_model( 223 | self, 224 | data_loader: Iterator, 225 | log_prefix: Optional[str] = '', 226 | ) -> Dict[str, Any]: 227 | """ 228 | Evaluates the model on a dataset. 229 | 230 | :param data_loader: Data loader of the dataset. 231 | :param log_prefix: Prefix to add to the keys of the logged metrics such as "Best" or "Validation". 232 | :return: A dictionary of the evaluation metrics. 233 | """ 234 | 235 | # Test model on all images of a data loader and return avg loss 236 | metrics = defaultdict(float) 237 | num_batches = 0 238 | 239 | for i, batch in enumerate(iter(data_loader)): 240 | 241 | if self.world_size > 1: 242 | batch = reshape_batch_per_device(batch, self.world_size) 243 | 244 | self.train_state, step_metrics = self.eval_step(self.train_state, batch) 245 | 246 | for key in step_metrics: 247 | metrics[key] += step_metrics[key] 248 | num_batches += 1 249 | 250 | prefix = log_prefix + ' ' if log_prefix else '' 251 | metrics = {(prefix + key): jnp.mean(metrics[key] / num_batches).item() for key in metrics} 252 | return metrics 253 | 254 | def is_new_model_better(self, new_metrics: Dict[str, Any], old_metrics: Dict[str, Any]) -> bool: 255 | """ 256 | Compares two sets of evaluation metrics to decide whether the 257 | new model is better than the previous ones or not. 258 | 259 | :params new_metrics: A dictionary of the evaluation metrics of the new model. 
260 | :params old_metrics: A dictionary of the evaluation metrics of the previously 261 | best model, i.e. the one to compare to. 262 | :return: True if the new model is better than the old one, and False otherwise. 263 | """ 264 | if len(old_metrics) == 0: 265 | return True 266 | for key, is_larger in [('val/val_metric', False), ('Performance/Validation accuracy', True), ('Performance/Validation loss', False)]: 267 | if key in new_metrics: 268 | if is_larger: 269 | return new_metrics[key] > old_metrics[key] 270 | else: 271 | return new_metrics[key] < old_metrics[key] 272 | assert False, f'No known metrics to log on: {new_metrics}' 273 | 274 | def save_metrics(self, filename: str, metrics: Dict[str, Any]): 275 | """ 276 | Saves a dictionary of metrics to file. Can be used as a textual 277 | representation of the validation performance for checking in the terminal. 278 | 279 | :param filename: The name of the file to save the metrics to. 280 | :param metrics: A dictionary of the metrics to save. 281 | """ 282 | with open(os.path.join(self.log_dir, f'metrics/{filename}.json'), 'w') as f: 283 | json.dump(metrics, f, indent=4) 284 | 285 | def save_model(self): 286 | """ 287 | Saves the model to a file. The model is saved in the log directory. 288 | """ 289 | if self.world_size > 1: 290 | state = jax_utils.unreplicate(self.train_state) 291 | else: 292 | state = self.train_state 293 | checkpoints.save_checkpoint( 294 | ckpt_dir=os.path.abspath(os.path.join(self.log_dir, 'checkpoints')), 295 | target=state, 296 | step=state.step, 297 | overwrite=True, 298 | keep=1 299 | ) 300 | del state 301 | 302 | def load_model(self): 303 | """ 304 | Loads the model from a file. The model is loaded from the log directory. 305 | """ 306 | if self.world_size > 1: 307 | state = jax_utils.unreplicate(self.train_state) 308 | raw_restored = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(self.log_dir, 'checkpoints'), target=state) 309 | self.train_state = jax_utils.replicate(raw_restored) 310 | del state 311 | else: 312 | self.train_state = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(self.log_dir, 'checkpoints'), target=self.train_state) 313 | 314 | def on_training_start(self): 315 | """ 316 | Method called before training is started. Can be used for additional 317 | initialization operations etc. 318 | """ 319 | pass 320 | 321 | def on_training_epoch_end(self, train_metrics): 322 | """ 323 | Method called at the end of each training epoch. Can be used for additional 324 | logging or similar. 325 | """ 326 | print('-' * 89) 327 | print(f"| end of epoch {self.epoch_idx:3d} | time per epoch: {train_metrics['epoch_time']:5.2f}s |") 328 | print('| Train Metrics |', ' | '.join( 329 | [f"{k.split('/')[1].replace('Training ', '')}: {v:5.2f}" for k, v in train_metrics.items() if 330 | 'Train' in k])) 331 | 332 | # check metrics for nan values and possibly exit training 333 | if jnp.isnan(train_metrics['Performance/Training loss']).item(): 334 | print("EXITING TRAINING DUE TO NAN LOSS") 335 | sys.exit(1) 336 | 337 | def on_validation_epoch_end(self, eval_metrics: Dict[str, Any]): 338 | """ 339 | Method called at the end of each validation epoch. Can be used for additional 340 | logging and evaluation. 341 | 342 | Args: 343 | eval_metrics: A dictionary of the validation metrics. New metrics added to 344 | this dictionary will be logged as well. 
345 | """ 346 | print('| Eval Metrics |', ' | '.join( 347 | [f"{k.split('/')[1].replace('Validation ', '')}: {v:5.2f}" for k, v in eval_metrics.items() if 348 | 'Validation' in k])) 349 | print('-' * 89) 350 | 351 | self.save_metrics(f'eval_epoch_{str(self.epoch_idx).zfill(3)}', eval_metrics) 352 | 353 | # Save best model 354 | if self.is_new_model_better(eval_metrics, self.best_eval_metrics): 355 | self.best_eval_metrics = eval_metrics 356 | self.save_model() 357 | self.save_metrics('best_eval', eval_metrics) 358 | -------------------------------------------------------------------------------- /event_ssm/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Identity: 5 | def __call__(self, events): 6 | return events 7 | 8 | 9 | class CropEvents: 10 | """Crops event stream to a specified number of events 11 | 12 | Parameters: 13 | num_events (int): number of events to keep 14 | """ 15 | 16 | def __init__(self, num_events): 17 | self.num_events = num_events 18 | 19 | def __call__(self, events): 20 | if self.num_events >= len(events): 21 | return events 22 | else: 23 | start = np.random.randint(0, len(events) - self.num_events) 24 | return events[start:start + self.num_events] 25 | 26 | 27 | class Jitter1D: 28 | """ 29 | Apply random jitter to event coordinates 30 | Parameters: 31 | max_roll (int): maximum number of pixels to roll by 32 | """ 33 | def __init__(self, sensor_size, var): 34 | self.sensor_size = sensor_size 35 | self.var = var 36 | 37 | def __call__(self, events): 38 | # roll x, y coordinates by a random amount 39 | shift = np.random.normal(0, self.var, len(events)).astype(np.int32) 40 | events['x'] += shift 41 | # remove events who got shifted out of the sensor size 42 | mask = (events['x'] >= 0) & (events['x'] < self.sensor_size[0]) 43 | events = events[mask] 44 | return events 45 | 46 | 47 | class Roll: 48 | """ 49 | Roll event x, y coordinates by a random amount 50 | 51 | Parameters: 52 | max_roll (int): maximum number of pixels to roll by 53 | """ 54 | def __init__(self, sensor_size, p, max_roll): 55 | self.sensor_size = sensor_size 56 | self.max_roll = max_roll 57 | self.p = p 58 | 59 | def __call__(self, events): 60 | if np.random.rand() > self.p: 61 | return events 62 | # roll x, y coordinates by a random amount 63 | roll_x = np.random.randint(-self.max_roll, self.max_roll) 64 | roll_y = np.random.randint(-self.max_roll, self.max_roll) 65 | events['x'] += roll_x 66 | events['y'] += roll_y 67 | # remove events who got shifted out of the sensor size 68 | mask = (events['x'] >= 0) & (events['x'] < self.sensor_size[0]) & (events['y'] >= 0) & (events['y'] < self.sensor_size[1]) 69 | events = events[mask] 70 | return events 71 | 72 | 73 | class Rotate: 74 | """ 75 | Rotate event x, y coordinates by a random angle 76 | """ 77 | def __init__(self, sensor_size, p, max_angle): 78 | self.p = p 79 | self.sensor_size = sensor_size 80 | self.max_angle = 2 * np.pi * max_angle / 360 81 | 82 | def __call__(self, events): 83 | if np.random.rand() > self.p: 84 | return events 85 | # rotate x, y coordinates by a random angle 86 | angle = np.random.uniform(-self.max_angle, self.max_angle) 87 | x = events['x'] - self.sensor_size[0] / 2 88 | y = events['y'] - self.sensor_size[1] / 2 89 | x_new = x * np.cos(angle) - y * np.sin(angle) 90 | y_new = x * np.sin(angle) + y * np.cos(angle) 91 | events['x'] = (x_new + self.sensor_size[0] / 2).astype(np.int32) 92 | events['y'] = (y_new + self.sensor_size[1] / 
2).astype(np.int32) 93 | # clip to original range 94 | events['x'] = np.clip(events['x'], 0, self.sensor_size[0]) 95 | events['y'] = np.clip(events['y'], 0, self.sensor_size[1]) 96 | return events 97 | 98 | 99 | class Scale: 100 | """ 101 | Scale event x, y coordinates by a random factor 102 | """ 103 | def __init__(self, sensor_size, p, max_scale): 104 | assert max_scale >= 1 105 | self.p = p 106 | self.sensor_size = sensor_size 107 | self.max_scale = max_scale 108 | 109 | def __call__(self, events): 110 | if np.random.rand() > self.p: 111 | return events 112 | # scale x, y coordinates by a random factor 113 | scale = np.random.uniform(1/self.max_scale, self.max_scale) 114 | x = events['x'] - self.sensor_size[0] / 2 115 | y = events['y'] - self.sensor_size[1] / 2 116 | x_new = x * scale 117 | y_new = y * scale 118 | events['x'] = (x_new + self.sensor_size[0] / 2).astype(np.int32) 119 | events['y'] = (y_new + self.sensor_size[1] / 2).astype(np.int32) 120 | # remove events who got shifted out of the sensor size 121 | mask = (events['x'] >= 0) & (events['x'] < self.sensor_size[0]) & (events['y'] >= 0) & (events['y'] < self.sensor_size[1]) 122 | events = events[mask] 123 | return events 124 | 125 | 126 | class DropEventChunk: 127 | """ 128 | Randomly drop a chunk of events 129 | """ 130 | def __init__(self, p, max_drop_size): 131 | self.drop_prob = p 132 | self.max_drop_size = max_drop_size 133 | 134 | def __call__(self, events): 135 | max_drop_events = self.max_drop_size * len(events) 136 | if np.random.rand() < self.drop_prob: 137 | drop_size = np.random.randint(1, max_drop_events) 138 | start = np.random.randint(0, len(events) - drop_size) 139 | events = np.delete(events, slice(start, start + drop_size), axis=0) 140 | return events 141 | 142 | 143 | class OneHotLabels: 144 | """ 145 | Convert integer labels to one-hot encoding 146 | """ 147 | def __init__(self, num_classes): 148 | self.num_classes = num_classes 149 | 150 | def __call__(self, label): 151 | return np.eye(self.num_classes)[label] 152 | 153 | 154 | def cut_mix_augmentation(events, targets): 155 | """ 156 | Cut and mix two event streams by a random event chunk. Input is a list of event streams. 
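The mixed target is the convex combination (1 - lam) * targets[i] + lam * targets[rand_index[i]], where lam is the fraction of events in the mixed stream that came from the donor stream.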
157 | 158 | Args: 159 | events (dict): batch of event streams of shape (batch_size, num_events, 4) 160 | max_num_events (int): maximum number of events to mix 161 | """ 162 | # get the total time of all events 163 | lengths = np.array([e.shape[0] for e in events]) 164 | 165 | # get fraction of the event-stream to cut 166 | cut_size = np.random.randint(low=1, high=lengths) 167 | start_event = np.random.randint(low=0, high=lengths - cut_size) 168 | 169 | # a random permutation to mix the events 170 | rand_index = np.random.permutation(len(events)) 171 | 172 | mixed_events = [] 173 | mixed_targets = [] 174 | 175 | # cut events from b and mix them with events from a 176 | for i in range(len(events)): 177 | events_b = events[rand_index[i]][start_event[rand_index[i]]:start_event[rand_index[i]] + cut_size[rand_index[i]]] 178 | mask_a = (events[i]['t'] >= events_b['t'][0]) & (events[i]['t'] <= events_b['t'][-1]) 179 | events_a = events[i][~mask_a] 180 | 181 | # mix and sort events 182 | new_events = np.concatenate([events_a, events_b]) 183 | new_events = new_events[np.argsort(new_events['t'])] 184 | 185 | # mix targets 186 | lam = events_b.shape[0] / new_events.shape[0] 187 | assert 0 <= lam <= 1, f'lam should be between 0 and 1, but got {lam} {cut_size[rand_index[i]]} {events_a.shape[0]} {events_b.shape[0]}' 188 | 189 | # append mixed events and targets 190 | mixed_events.append(new_events) 191 | mixed_targets.append(targets[i] * (1 - lam) + targets[rand_index[i]] * lam) 192 | 193 | return mixed_events, mixed_targets 194 | 195 | 196 | def cut_mix_augmentation_time(events, targets): 197 | """ 198 | Cut and mix two event streams by a random event chunk. Input is a list of event streams. 199 | 200 | :param events: batch of event streams of shape (batch_size, num_events, 4) 201 | :param targets: batch of targets of shape (batch_size, num_classes) 202 | 203 | :return: mixed events, mixed targets 204 | """ 205 | # get the total time of all events 206 | lengths = np.array([e['t'][-1] - e['t'][0] for e in events], dtype=np.float32) 207 | 208 | # get fraction of the event-stream to cut 209 | cut_size = np.random.uniform(low=0, high=lengths) 210 | start_time = np.random.uniform(low=0, high=lengths - cut_size) 211 | 212 | # a random permutation to mix the events 213 | rand_index = np.random.permutation(len(events)) 214 | 215 | mixed_events = [] 216 | mixed_targets = [] 217 | 218 | # cut events from b and mix them with events from a 219 | for i in range(len(events)): 220 | start, end = start_time[rand_index[i]], start_time[rand_index[i]] + cut_size[rand_index[i]] 221 | mask_a = (events[i]['t'] >= start) & (events[i]['t'] <= end) 222 | mask_b = (events[rand_index[i]]['t'] >= start) & (events[rand_index[i]]['t'] <= end) 223 | 224 | # mix events 225 | new_events = np.concatenate([events[i][~mask_a], events[rand_index[i]][mask_b]]) 226 | 227 | # avoid the case that the new events are empty 228 | if len(new_events) == 0: 229 | mixed_events.append(events[i]) 230 | mixed_targets.append(targets[i]) 231 | else: 232 | # sort events 233 | new_events = new_events[np.argsort(new_events['t'])] 234 | mixed_events.append(new_events) 235 | 236 | # mix targets 237 | new_length = new_events['t'][-1] - new_events['t'][0] 238 | if len(events[rand_index[i]]['t'][mask_b]) == 0: 239 | cut_length = 0 240 | else: 241 | cut_length = events[rand_index[i]]['t'][mask_b][-1] - events[rand_index[i]]['t'][mask_b][0] 242 | lam = cut_length / new_length 243 | assert 0 <= lam <= 1, f'lam should be between 0 and 1, but got {lam} {new_length} 
{cut_size[rand_index[i]]} {start} {end}' 244 | mixed_targets.append(targets[i] * (1 - lam) + targets[rand_index[i]] * lam) 245 | 246 | return mixed_events, mixed_targets 247 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | jax[cuda12] 3 | flax 4 | jaxtyping 5 | optax 6 | torch 7 | torchvision 8 | torchaudio 9 | wandb 10 | tonic 11 | hydra-core -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import OmegaConf as om 3 | from omegaconf import DictConfig, open_dict 4 | from functools import partial 5 | 6 | import jax.random 7 | import jax.numpy as jnp 8 | import optax 9 | from flax.training import checkpoints 10 | 11 | from event_ssm.dataloading import Datasets 12 | from event_ssm.ssm import init_S5SSM 13 | from event_ssm.seq_model import BatchClassificationModel 14 | 15 | 16 | def setup_evaluation(cfg: DictConfig): 17 | num_devices = jax.local_device_count() 18 | assert cfg.checkpoint, "No checkpoint directory provided. Use checkpoint= to specify a checkpoint." 19 | 20 | # load task specific data 21 | create_dataset_fn = Datasets[cfg.task.name] 22 | 23 | # Create dataset... 24 | print("[*] Loading dataset...") 25 | train_loader, val_loader, test_loader, data = create_dataset_fn( 26 | cache_dir=cfg.data_dir, 27 | seed=cfg.seed, 28 | world_size=num_devices, 29 | **cfg.training 30 | ) 31 | 32 | with open_dict(cfg): 33 | # optax updates the schedule every iteration and not every epoch 34 | cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps 35 | cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps 36 | 37 | # scale learning rate by batch size 38 | cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * num_devices * cfg.optimizer.accumulation_steps 39 | 40 | # load model 41 | print("[*] Creating model...") 42 | ssm_init_fn = init_S5SSM(**cfg.model.ssm_init) 43 | model = BatchClassificationModel( 44 | ssm=ssm_init_fn, 45 | num_classes=data.n_classes, 46 | num_embeddings=data.num_embeddings, 47 | **cfg.model.ssm, 48 | ) 49 | 50 | # initialize training state 51 | state = checkpoints.restore_checkpoint(cfg.checkpoint, target=None) 52 | params = state['params'] 53 | model_state = state['model_state'] 54 | 55 | return model, params, model_state, train_loader, val_loader, test_loader 56 | 57 | 58 | def evaluation_step( 59 | apply_fn, 60 | params, 61 | model_state, 62 | batch 63 | ): 64 | """ 65 | Evaluates the loss of the function passed as argument on a batch 66 | 67 | :param train_state: a Flax TrainState that carries the parameters, optimizer states etc 68 | :param batch: the data consisting of [data, target] 69 | :return: train_state, metrics 70 | """ 71 | inputs, targets, integration_timesteps, lengths = batch 72 | logits = apply_fn( 73 | 74 | {'params': params, **model_state}, 75 | inputs, integration_timesteps, lengths, 76 | False, 77 | ) 78 | 79 | loss = optax.softmax_cross_entropy(logits, targets) 80 | loss = loss.mean() 81 | preds = jnp.argmax(logits, axis=-1) 82 | targets = jnp.argmax(targets, axis=-1) 83 | accuracy = (preds == targets).mean() 84 | 85 | return {'loss': loss, 'accuracy': 
accuracy}, preds 86 | 87 | 88 | @hydra.main(version_base=None, config_path='configs', config_name='base') 89 | def main(config: DictConfig): 90 | print(om.to_yaml(config)) 91 | 92 | model, params, model_state, train_loader, val_loader, test_loader = setup_evaluation(cfg=config) 93 | step = partial(evaluation_step, model.apply, params, model_state) 94 | step = jax.jit(step) 95 | 96 | # run training 97 | print("[*] Running evaluation...") 98 | metrics = {} 99 | events_per_sample = [] 100 | time_per_sample = [] 101 | targets = [] 102 | predictions = [] 103 | num_batches = 0 104 | 105 | for i, batch in enumerate(test_loader): 106 | step_metrics, preds = step(batch) 107 | 108 | predictions.append(preds) 109 | targets.append(jnp.argmax(batch[1], axis=-1)) 110 | time_per_sample.append(jnp.sum(batch[2], axis=1)) 111 | events_per_sample.append(batch[3]) 112 | 113 | if not metrics: 114 | metrics = step_metrics 115 | else: 116 | for key, val in step_metrics.items(): 117 | metrics[key] += val 118 | num_batches += 1 119 | 120 | metrics = {key: jnp.mean(metrics[key] / num_batches).item() for key in metrics} 121 | 122 | print(f"[*] Test accuracy: {100 * metrics['accuracy']:.2f}%") 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import OmegaConf as om 3 | from omegaconf import DictConfig, open_dict 4 | from functools import partial 5 | import os 6 | 7 | import jax.random 8 | from flax import jax_utils 9 | from flax.training import checkpoints 10 | 11 | from event_ssm.dataloading import Datasets 12 | from event_ssm.ssm import init_S5SSM 13 | from event_ssm.seq_model import BatchClassificationModel 14 | from event_ssm.train_utils import training_step, evaluation_step, init_model_state 15 | from event_ssm.trainer import TrainerModule 16 | 17 | 18 | def setup_training(key, cfg: DictConfig): 19 | num_devices = jax.local_device_count() 20 | 21 | # load task specific data 22 | create_dataset_fn = Datasets[cfg.task.name] 23 | 24 | # Create dataset... 
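# NOTE: create_dataset_fn returns the train/val/test loaders plus a metadata object whose n_classes and num_embeddings fields size the decoder and the embedding layer below.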
25 | print("[*] Loading dataset...") 26 | train_loader, val_loader, test_loader, data = create_dataset_fn( 27 | cache_dir=cfg.data_dir, 28 | seed=cfg.seed, 29 | world_size=num_devices, 30 | **cfg.training 31 | ) 32 | 33 | with open_dict(cfg): 34 | # optax updates the schedule every iteration and not every epoch 35 | cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps 36 | cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps 37 | 38 | # scale learning rate by batch size 39 | cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * num_devices * cfg.optimizer.accumulation_steps 40 | 41 | # load model 42 | print("[*] Creating model...") 43 | ssm_init_fn = init_S5SSM(**cfg.model.ssm_init) 44 | model = BatchClassificationModel( 45 | ssm=ssm_init_fn, 46 | num_classes=data.n_classes, 47 | num_embeddings=data.num_embeddings, 48 | **cfg.model.ssm, 49 | ) 50 | 51 | # initialize training state 52 | print("[*] Initializing model state...") 53 | single_bsz = cfg.training.per_device_batch_size 54 | batch = next(iter(train_loader)) 55 | inputs, targets, timesteps, lengths = batch 56 | state = init_model_state(key, model, inputs[:single_bsz], timesteps[:single_bsz], lengths[:single_bsz], cfg.optimizer) 57 | 58 | if cfg.training.get('from_checkpoint', None): 59 | print(f'[*] Resuming model from {cfg.training.from_checkpoint}') 60 | state = checkpoints.restore_checkpoint(cfg.training.from_checkpoint, state) 61 | 62 | # check if multiple GPUs are available and distribute training 63 | if num_devices >= 2: 64 | print(f"[*] Running training on {num_devices} GPUs") 65 | state = jax_utils.replicate(state) 66 | train_step = jax.pmap( 67 | partial(training_step, distributed=True), 68 | axis_name='data', 69 | ) 70 | eval_step = jax.pmap( 71 | partial(evaluation_step, distributed=True), 72 | axis_name='data' 73 | ) 74 | else: 75 | train_step = jax.jit( 76 | training_step 77 | ) 78 | eval_step = jax.jit( 79 | evaluation_step 80 | ) 81 | 82 | # set up trainer module 83 | trainer = TrainerModule( 84 | train_state=state, 85 | training_step_fn=train_step, 86 | evaluation_step_fn=eval_step, 87 | world_size=num_devices, 88 | config=cfg, 89 | ) 90 | 91 | return trainer, train_loader, val_loader, test_loader 92 | 93 | 94 | @hydra.main(version_base=None, config_path='configs', config_name='base') 95 | def main(config: DictConfig): 96 | # print config and save to log directory 97 | print(om.to_yaml(config)) 98 | with open(os.path.join(config.logging.log_dir, 'config.yaml'), 'w') as f: 99 | om.save(config, f) 100 | 101 | # Set the random seed manually for reproducibility. 
102 | key = jax.random.PRNGKey(config.seed) 103 | init_key, dropout_key = jax.random.split(key) 104 | 105 | if jax.local_device_count() > 1: 106 | dropout_key = jax.random.split(dropout_key, jax.local_device_count()) 107 | 108 | trainer, train_loader, val_loader, test_loader = setup_training(key=init_key, cfg=config) 109 | 110 | # run training 111 | print("[*] Running training...") 112 | trainer.train_model( 113 | train_loader=train_loader, 114 | val_loader=val_loader, 115 | test_loader=test_loader, 116 | dropout_key=dropout_key 117 | ) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name='Event-based-SSM', 5 | packages=['event_ssm'], 6 | version='0.1', 7 | description='Event-stream modeling with state-space models', 8 | author='Mark Schoene', 9 | author_email='mark.schoene@tu-dresden.de', 10 | ) 11 | -------------------------------------------------------------------------------- /tutorial_inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "In this notebook, we demonstrate how to evaluate a trained event-SSM model on batches of unseen data on the three tasks:\n", 8 | " \n", 9 | "1) Spiking Speech Commands\n", 10 | "2) Spiking Heidelberg Digits\n", 11 | "3) DVS128 Gesture \n", 12 | "\n", 13 | "\n", 14 | "\n", 15 | "\n", 16 | "# Setup\n", 17 | "\n", 18 | "Install and load the important modules and configuration.\n", 19 | "\n", 20 | "To install required packages, please do ``` pip3 install -r requirements.txt ```
\n", 21 | "\n", 22 | "Directories for loading datasets, model checkpoints and saving results are defined in the configuration file `system/local.yaml`.\n", 23 | "Please set your directories accordingly.\n", 24 | "\n", 25 | "The trained model checkpoints are [available for download](https://datashare.tu-dresden.de/s/g2dQCi792B8DqnC).\n", 26 | "\n", 27 | "## Important Libraries\n", 28 | "* [Hydra](https://hydra.cc/docs/intro/) - to manage configurations.\n", 29 | "* [Flax](https://flax.readthedocs.io/en/latest/), Neural network package built on top of [Jax](https://jax.readthedocs.io/en/latest/) - for model development\n", 30 | "* [Tonic](https://tonic.readthedocs.io/en/latest/) - for datasets" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "metadata": {}, 36 | "source": [ 37 | "import os\n", 38 | "from pathlib import Path\n", 39 | "from functools import partial\n", 40 | "from typing import Optional, TypeVar, Tuple, Union\n", 41 | "\n", 42 | "import numpy as np\n", 43 | "import seaborn as sns\n", 44 | "import matplotlib.pyplot as plt\n", 45 | "from sklearn.metrics import confusion_matrix\n", 46 | "\n", 47 | "import torch\n", 48 | "import tonic\n", 49 | "\n", 50 | "import jax\n", 51 | "import jax.numpy as jnp\n", 52 | "from flax.training import checkpoints\n", 53 | "\n", 54 | "from hydra import initialize, compose\n", 55 | "from omegaconf import OmegaConf as om\n", 56 | "\n", 57 | "from event_ssm.ssm import init_S5SSM\n", 58 | "from event_ssm.seq_model import BatchClassificationModel" 59 | ], 60 | "outputs": [], 61 | "execution_count": null 62 | }, 63 | { 64 | "cell_type": "code", 65 | "metadata": {}, 66 | "source": "os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Turn off GPU", 67 | "outputs": [], 68 | "execution_count": null 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "# Task 1 - Spiking Heidelberg Digits\n", 75 | "\n", 76 | "Spike-based version of Heidelberg digits dataset, consist of approximately 10k high-quality recordings of spoken digits ranging from zero to nine in English and German language. In total 12 speakers were included, six of which were female and six male. \n", 77 | "\n", 78 | "Two speakers were heldout exclusively for the test set. 
103 | {
104 | "cell_type": "code",
105 | "metadata": {},
106 | "source": [
107 | "# See the model config:\n",
108 | "print(om.to_yaml(cfg.model))"
109 | ],
110 | "outputs": [],
111 | "execution_count": null
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {},
116 | "source": [
117 | "### Step 2 : Visualise data"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "metadata": {},
123 | "source": [
124 | "data = tonic.datasets.SHD(cfg.data_dir, train=False)\n",
125 | "audio_events, label = data[0]\n",
126 | "tonic.utils.plot_event_grid(audio_events)"
127 | ],
128 | "outputs": [],
129 | "execution_count": null
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "### Step 3: Load single data sample for inference"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "metadata": {},
141 | "source": [
142 | "DEFAULT_CACHE_DIR_ROOT = Path('./cache_dir/')\n",
143 | "DataLoader = TypeVar('DataLoader')\n",
145 | "class Data:\n",
146 | "    def __init__(\n",
147 | "        self,\n",
148 | "        n_classes: int,\n",
149 | "        num_embeddings: int\n",
150 | "    ):\n",
151 | "        self.n_classes = n_classes\n",
152 | "        self.num_embeddings = num_embeddings"
153 | ],
154 | "outputs": [],
155 | "execution_count": null
156 | },
157 | {
158 | "cell_type": "code",
159 | "metadata": {},
160 | "source": [
161 | "def event_stream_collate_fn(batch, resolution, pad_unit, no_time_information=False):\n",
162 | "    # x are inputs, y are targets, z are aux data\n",
163 | "    x, y, *z = zip(*batch)\n",
164 | "    assert len(z) == 0\n",
165 | "    batch_size_one = len(x) == 1\n",
166 | "\n",
167 | "    # set labels to numpy array\n",
168 | "    y = np.stack(y)\n",
169 | "\n",
170 | "    # integration time steps are the difference between two consecutive time stamps\n",
171 | "    if no_time_information:\n",
172 | "        timesteps = [np.ones_like(e['t'][:-1]) for e in x]\n",
173 | "    else:\n",
174 | "        timesteps = [np.diff(e['t']) for e in x]\n",
175 | "\n",
176 | "    # NOTE: since timesteps are deltas, their length is L - 1, and we have to remove the last token in the following\n",
177 | "\n",
178 | "    # process tokens for single input dim (e.g. audio)\n",
179 | "    if len(resolution) == 1:\n",
180 | "        tokens = [e['x'][:-1].astype(np.int32) for e in x]\n",
181 | "    elif len(resolution) == 2:\n",
182 | "        tokens = [(e['x'][:-1].astype(np.int32) * resolution[1] + e['y'][:-1] + np.prod(resolution) * e['p'][:-1].astype(np.int32)).astype(np.int32) for e in x]\n",
183 | "    else:\n",
184 | "        raise ValueError('resolution must contain 1 or 2 elements')\n",
185 | "\n",
186 | "    # get padding lengths\n",
187 | "    lengths = np.array([len(e) for e in timesteps], dtype=np.int32)\n",
188 | "    pad_length = (lengths.max() // pad_unit) * pad_unit + pad_unit\n",
189 | "\n",
190 | "    # pad tokens with -1, which results in a zero vector with embedding look-ups\n",
191 | "    tokens = np.stack(\n",
192 | "        [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=-1) for e in tokens])\n",
193 | "    timesteps = np.stack(\n",
194 | "        [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=0) for e in timesteps])\n",
195 | "\n",
196 | "    # timesteps are in microseconds... transform to milliseconds\n",
197 | "    timesteps = timesteps / 1000\n",
198 | "\n",
199 | "    if batch_size_one:\n",
200 | "        lengths = lengths[None, ...]\n",
201 | "\n",
202 | "    return tokens, y, timesteps, lengths"
203 | ],
204 | "outputs": [],
205 | "execution_count": null
206 | },
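{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two details of the collate function are worth spelling out. First, for 2D sensors, an event at pixel `(x, y)` with polarity `p` is flattened into a unique integer token `x * resolution[1] + y + prod(resolution) * p`, so every (pixel, polarity) pair gets its own embedding. Second, batches are padded to a multiple of `pad_unit`, which keeps the number of distinct input shapes small and thus limits jit re-compilation. A small worked example with illustrative values:\n",
"```python\n",
"import numpy as np\n",
"\n",
"lengths = np.array([3000, 5000], dtype=np.int32)  # raw event counts in a batch\n",
"pad_unit = 8192\n",
"pad_length = (lengths.max() // pad_unit) * pad_unit + pad_unit\n",
"print(pad_length)  # 8192: both sequences are padded to a single pad_unit\n",
"```\n"
]
},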
audio)\n", 179 | " if len(resolution) == 1:\n", 180 | " tokens = [e['x'][:-1].astype(np.int32) for e in x]\n", 181 | " elif len(resolution) == 2:\n", 182 | " tokens = [(e['x'][:-1] * e['y'][:-1] + np.prod(resolution) * e['p'][:-1].astype(np.int32)).astype(np.int32) for e in x]\n", 183 | " else:\n", 184 | " raise ValueError('resolution must contain 1 or 2 elements')\n", 185 | "\n", 186 | " # get padding lengths\n", 187 | " lengths = np.array([len(e) for e in timesteps], dtype=np.int32)\n", 188 | " pad_length = (lengths.max() // pad_unit) * pad_unit + pad_unit\n", 189 | "\n", 190 | " # pad tokens with -1, which results in a zero vector with embedding look-ups\n", 191 | " tokens = np.stack(\n", 192 | " [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=-1) for e in tokens])\n", 193 | " timesteps = np.stack(\n", 194 | " [np.pad(e, (0, pad_length - len(e)), mode='constant', constant_values=0) for e in timesteps])\n", 195 | "\n", 196 | " # timesteps are in micro seconds... transform to milliseconds\n", 197 | " timesteps = timesteps / 1000\n", 198 | "\n", 199 | " if batch_size_one:\n", 200 | " lengths = lengths[None, ...]\n", 201 | "\n", 202 | " return tokens, y, timesteps, lengths" 203 | ], 204 | "outputs": [], 205 | "execution_count": null 206 | }, 207 | { 208 | "cell_type": "code", 209 | "metadata": {}, 210 | "source": [ 211 | "def event_stream_dataloader(test_data,eval_batch_size,eval_collate_fn, rng, num_workers=0):\n", 212 | " def dataloader(dset, bsz, collate_fn, shuffle, drop_last):\n", 213 | " return torch.utils.data.DataLoader(\n", 214 | " dset,\n", 215 | " batch_size=bsz,\n", 216 | " drop_last=drop_last,\n", 217 | " collate_fn=collate_fn,\n", 218 | " shuffle=shuffle,\n", 219 | " generator=rng,\n", 220 | " num_workers=num_workers\n", 221 | " )\n", 222 | " test_loader = dataloader(test_data, eval_batch_size, eval_collate_fn, shuffle=True, drop_last=False)\n", 223 | " return test_loader" 224 | ], 225 | "outputs": [], 226 | "execution_count": null 227 | }, 228 | { 229 | "cell_type": "code", 230 | "metadata": {}, 231 | "source": [ 232 | "def create_events_shd_classification_dataset(\n", 233 | " cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,\n", 234 | " per_device_eval_batch_size: int = 64,\n", 235 | " world_size: int = 1,\n", 236 | " num_workers: int = 0,\n", 237 | " seed: int = 42,\n", 238 | " pad_unit: int = 8192,\n", 239 | " no_time_information: bool = False,\n", 240 | " **kwargs\n", 241 | ") -> Tuple[DataLoader, Data]:\n", 242 | " \"\"\"\n", 243 | " creates a view of the spiking heidelberg digits dataset\n", 244 | "\n", 245 | " :param cache_dir:\t\t (str):\t\twhere to store the dataset\n", 246 | " :param per_device_eval_batch_size:\t\t\t\t(int):\t\tEvaluation Batch size.\n", 247 | " :param seed:\t\t\t (int):\t\tSeed for shuffling data.\n", 248 | " \"\"\"\n", 249 | " print(\"[*] Generating Spiking Heidelberg Digits Classification Dataset\")\n", 250 | "\n", 251 | " if seed is not None:\n", 252 | " rng = torch.Generator()\n", 253 | " rng.manual_seed(seed)\n", 254 | " else:\n", 255 | " rng = None\n", 256 | " \n", 257 | " #target_transforms = OneHotLabels(num_classes=20)\n", 258 | " test_data = tonic.datasets.SHD(save_to=cache_dir, train=False)\n", 259 | " collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)\n", 260 | " test_loader = event_stream_dataloader(\n", 261 | " test_data,\n", 262 | " eval_collate_fn=collate_fn,\n", 263 | " eval_batch_size=per_device_eval_batch_size * world_size,\n", 264 
| " rng=rng, \n", 265 | " num_workers=num_workers\n", 266 | " )\n", 267 | " data = Data(\n", 268 | " n_classes=20, num_embeddings=700)\n", 269 | " return test_loader, data" 270 | ], 271 | "outputs": [], 272 | "execution_count": null 273 | }, 274 | { 275 | "cell_type": "code", 276 | "metadata": {}, 277 | "source": [ 278 | "print(\"[*] Loading dataset...\")\n", 279 | "num_devices = jax.local_device_count()\n", 280 | "test_loader, data = create_events_shd_classification_dataset(\n", 281 | " cache_dir=cfg.data_dir,\n", 282 | " seed=cfg.seed,\n", 283 | " world_size=num_devices,\n", 284 | " per_device_eval_batch_size = 1,\n", 285 | " pad_unit=cfg.training.pad_unit \n", 286 | " )" 287 | ], 288 | "outputs": [], 289 | "execution_count": null 290 | }, 291 | { 292 | "cell_type": "code", 293 | "metadata": {}, 294 | "source": [ 295 | "# Load a sample\n", 296 | "batch = next(iter(test_loader))\n", 297 | "inputs, targets, timesteps, lengths = batch" 298 | ], 299 | "outputs": [], 300 | "execution_count": null 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "### Step 4 : Load model" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "metadata": {}, 312 | "source": [ 313 | "# Set the random seed manually for reproducibility.\n", 314 | "init_key = jax.random.PRNGKey(cfg.seed)" 315 | ], 316 | "outputs": [], 317 | "execution_count": null 318 | }, 319 | { 320 | "cell_type": "code", 321 | "metadata": {}, 322 | "source": [ 323 | "# Model initialisation in flax\n", 324 | "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", 325 | "model = BatchClassificationModel(\n", 326 | " ssm=ssm_init_fn,\n", 327 | " num_classes=data.n_classes,\n", 328 | " num_embeddings=data.num_embeddings,\n", 329 | " **cfg.model.ssm,\n", 330 | " )" 331 | ], 332 | "outputs": [], 333 | "execution_count": null 334 | }, 335 | { 336 | "cell_type": "code", 337 | "metadata": {}, 338 | "source": [ 339 | "# Visualise model\n", 340 | "print(model.tabulate({\"params\": init_key},\n", 341 | " inputs, timesteps, lengths, False))" 342 | ], 343 | "outputs": [], 344 | "execution_count": null 345 | }, 346 | { 347 | "cell_type": "code", 348 | "metadata": {}, 349 | "source": [ 350 | "checkpoint_dir = os.path.abspath(os.path.join(cfg.checkpoint_dir, 'SHD'))\n", 351 | "training_state = checkpoints.restore_checkpoint(checkpoint_dir, target=None)\n", 352 | "params = training_state['params']\n", 353 | "model_state = training_state['model_state']" 354 | ], 355 | "outputs": [], 356 | "execution_count": null 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "### Step 5 - Model prediction" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "metadata": {}, 368 | "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)", 369 | "outputs": [], 370 | "execution_count": null 371 | }, 372 | { 373 | "cell_type": "code", 374 | "metadata": {}, 375 | "source": [ 376 | "print(f\"Predicted label:{jnp.argmax(logits,axis=-1)}\")\n", 377 | "print(f\"Actual label:{targets}\")" 378 | ], 379 | "outputs": [], 380 | "execution_count": null 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "### Step 6 - Evaluate model on a batch" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "metadata": {}, 392 | "source": [ 393 | "print(\"[*] Loading dataset...\")\n", 394 | "num_devices = jax.local_device_count()\n", 395 | "test_loader, data = create_events_shd_classification_dataset(\n", 396 | " 
cache_dir=cfg.data_dir,\n",
397 | "    seed=cfg.seed,\n",
398 | "    world_size=num_devices,\n",
399 | "    per_device_eval_batch_size=cfg.training.per_device_eval_batch_size,\n",
400 | "    pad_unit=cfg.training.pad_unit,\n",
401 | "    #no_time_information = cfg.training.no_time_information\n",
402 | "    \n",
403 | "    )"
404 | ],
405 | "outputs": [],
406 | "execution_count": null
407 | },
408 | {
409 | "cell_type": "code",
410 | "metadata": {},
411 | "source": [
412 | "# Load a batch\n",
413 | "batch = next(iter(test_loader))\n",
414 | "inputs, targets, timesteps, lengths = batch\n",
415 | "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)"
416 | ],
417 | "outputs": [],
418 | "execution_count": null
419 | },
420 | {
421 | "cell_type": "code",
422 | "metadata": {},
423 | "source": [
424 | "# Plot the confusion matrix\n",
425 | "cm = confusion_matrix(jnp.argmax(logits, axis=1), targets)\n",
426 | "sns.heatmap(cm, annot=True, fmt='d', cmap='YlGnBu')\n",
427 | "plt.ylabel('Prediction', fontsize=12)\n",
428 | "plt.xlabel('Actual', fontsize=12)\n",
429 | "plt.title('Confusion Matrix', fontsize=16)\n",
430 | "plt.show()"
431 | ],
432 | "outputs": [],
433 | "execution_count": null
434 | },
435 | {
436 | "cell_type": "code",
437 | "metadata": {},
438 | "source": [
439 | "print(f\"Accuracy of the model: {(jnp.argmax(logits, axis=1) == targets).mean()}\")"
440 | ],
441 | "outputs": [],
442 | "execution_count": null
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {},
447 | "source": [
448 | "# Task 2 - Spiking Speech Commands \n",
449 | "\n",
450 | "The Spiking Speech Commands dataset is based on the Speech Commands release by Google, which consists of utterances recorded under less controlled conditions. 
It contains 35 word categories from a larger number of speakers.\n", 451 | "\n", 452 | "Ref : https://arxiv.org/pdf/1910.07407v3" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "metadata": {}, 458 | "source": [ 459 | "### Step 1 : Load configuration" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "metadata": {}, 465 | "source": [ 466 | "# Load configurations\n", 467 | "with initialize(version_base=None, config_path=\"configs\"):\n", 468 | " cfg = compose(config_name=\"base.yaml\",overrides=[\"task=spiking-speech-commands\"])" 469 | ], 470 | "outputs": [], 471 | "execution_count": null 472 | }, 473 | { 474 | "cell_type": "code", 475 | "metadata": {}, 476 | "source": [ 477 | "# See the model config:\n", 478 | "print(om.to_yaml(cfg.model))" 479 | ], 480 | "outputs": [], 481 | "execution_count": null 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "metadata": {}, 486 | "source": [ 487 | "### Step 2 : Visualise data" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "metadata": {}, 493 | "source": [ 494 | "data = tonic.datasets.SSC(cfg.data_dir, split='test')\n", 495 | "audio_events, label = data[0]\n", 496 | "tonic.utils.plot_event_grid(audio_events)" 497 | ], 498 | "outputs": [], 499 | "execution_count": null 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "metadata": {}, 504 | "source": [ 505 | "### Step 3: Load single data sample for inference" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "metadata": {}, 511 | "source": [ 512 | "def create_events_ssc_classification_dataset(\n", 513 | " cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,\n", 514 | " per_device_eval_batch_size: int = 64,\n", 515 | " world_size: int = 1,\n", 516 | " num_workers: int = 0,\n", 517 | " seed: int = 42,\n", 518 | " pad_unit: int = 8192,\n", 519 | " no_time_information: bool = False,\n", 520 | " **kwargs\n", 521 | ") -> Tuple[DataLoader, DataLoader, DataLoader, Data]:\n", 522 | " \"\"\"\n", 523 | " creates a view of the spiking speech commands dataset\n", 524 | "\n", 525 | " :param cache_dir:\t\t(str):\t\twhere to store the dataset\n", 526 | " :param bsz:\t\t\t\t(int):\t\tBatch size.\n", 527 | " :param seed:\t\t\t(int)\t\tSeed for shuffling data.\n", 528 | " \"\"\"\n", 529 | " print(\"[*] Generating Spiking Speech Commands Classification Dataset\")\n", 530 | "\n", 531 | " if seed is not None:\n", 532 | " rng = torch.Generator()\n", 533 | " rng.manual_seed(seed)\n", 534 | " else:\n", 535 | " rng = None\n", 536 | "\n", 537 | " test_data = tonic.datasets.SSC(save_to=cache_dir, split='test')\n", 538 | " collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=pad_unit, no_time_information=no_time_information)\n", 539 | " test_loader = event_stream_dataloader(\n", 540 | " test_data,\n", 541 | " eval_collate_fn=collate_fn,\n", 542 | " eval_batch_size=per_device_eval_batch_size * world_size,\n", 543 | " rng=rng, \n", 544 | " num_workers=num_workers,\n", 545 | " )\n", 546 | "\n", 547 | " data = Data(\n", 548 | " n_classes=35, num_embeddings=700\n", 549 | " )\n", 550 | " return test_loader, data\n" 551 | ], 552 | "outputs": [], 553 | "execution_count": null 554 | }, 555 | { 556 | "cell_type": "code", 557 | "metadata": {}, 558 | "source": [ 559 | "print(\"[*] Loading dataset...\")\n", 560 | "num_devices = jax.local_device_count()\n", 561 | "test_loader, data = create_events_ssc_classification_dataset(\n", 562 | " cache_dir=cfg.data_dir,\n", 563 | " seed=cfg.seed,\n", 564 | " world_size=num_devices,\n", 565 | " per_device_eval_batch_size = 1,\n", 566 
| " pad_unit=cfg.training.pad_unit \n", 567 | " )" 568 | ], 569 | "outputs": [], 570 | "execution_count": null 571 | }, 572 | { 573 | "cell_type": "code", 574 | "metadata": {}, 575 | "source": [ 576 | "# Load a sample\n", 577 | "batch = next(iter(test_loader))\n", 578 | "inputs, targets, timesteps, lengths = batch" 579 | ], 580 | "outputs": [], 581 | "execution_count": null 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": {}, 586 | "source": [ 587 | "### Step 4 : Load model" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "metadata": {}, 593 | "source": [ 594 | "# Set the random seed manually for reproducibility.\n", 595 | "init_key = jax.random.PRNGKey(cfg.seed)" 596 | ], 597 | "outputs": [], 598 | "execution_count": null 599 | }, 600 | { 601 | "cell_type": "code", 602 | "metadata": {}, 603 | "source": [ 604 | "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", 605 | "model = BatchClassificationModel(\n", 606 | " ssm=ssm_init_fn,\n", 607 | " num_classes=data.n_classes,\n", 608 | " num_embeddings=data.num_embeddings,\n", 609 | " **cfg.model.ssm,\n", 610 | " )" 611 | ], 612 | "outputs": [], 613 | "execution_count": null 614 | }, 615 | { 616 | "cell_type": "code", 617 | "metadata": {}, 618 | "source": [ 619 | "print(model.tabulate({\"params\": init_key},\n", 620 | " inputs, timesteps, lengths, False))" 621 | ], 622 | "outputs": [], 623 | "execution_count": null 624 | }, 625 | { 626 | "cell_type": "code", 627 | "metadata": {}, 628 | "source": [ 629 | "# load model parameters from checkpoint\n", 630 | "checkpoint_dir = os.path.abspath(os.path.join(cfg.checkpoint_dir, 'SSC'))\n", 631 | "training_state = checkpoints.restore_checkpoint(checkpoint_dir, target=None)\n", 632 | "params = training_state['params']\n", 633 | "model_state = training_state['model_state']" 634 | ], 635 | "outputs": [], 636 | "execution_count": null 637 | }, 638 | { 639 | "cell_type": "markdown", 640 | "metadata": {}, 641 | "source": [ 642 | "### Step 5 - Model prediction" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "metadata": {}, 648 | "source": "logits = model.apply({'params': params, **model_state},inputs, timesteps, lengths,False)", 649 | "outputs": [], 650 | "execution_count": null 651 | }, 652 | { 653 | "cell_type": "code", 654 | "metadata": {}, 655 | "source": [ 656 | "print(f\"Predicted label:{jnp.argmax(logits,axis=-1)}\")\n", 657 | "print(f\"Actual label:{targets}\")" 658 | ], 659 | "outputs": [], 660 | "execution_count": null 661 | }, 662 | { 663 | "cell_type": "markdown", 664 | "metadata": {}, 665 | "source": [ 666 | "### Step 6 - Evaluate model on single batch" 667 | ] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "metadata": {}, 672 | "source": [ 673 | "print(\"[*] Loading dataset...\")\n", 674 | "num_devices = jax.local_device_count()\n", 675 | "test_loader, data = create_events_ssc_classification_dataset(\n", 676 | " cache_dir=cfg.data_dir,\n", 677 | " seed=cfg.seed,\n", 678 | " world_size=num_devices,\n", 679 | " per_device_eval_batch_size = cfg.training.per_device_eval_batch_size,\n", 680 | " pad_unit=cfg.training.pad_unit,\n", 681 | " #no_time_information = cfg.training.no_time_information\n", 682 | " \n", 683 | " )" 684 | ], 685 | "outputs": [], 686 | "execution_count": null 687 | }, 688 | { 689 | "cell_type": "code", 690 | "metadata": {}, 691 | "source": [ 692 | "# Load a batch\n", 693 | "batch = next(iter(test_loader))\n", 694 | "inputs, targets, timesteps, lengths = batch" 695 | ], 696 | "outputs": [], 697 | "execution_count": null 698 | }, 699 | { 700 | 
"cell_type": "code", 701 | "metadata": {}, 702 | "source": "logits = model.apply({'params': params, **model_state},inputs, timesteps, lengths,False)", 703 | "outputs": [], 704 | "execution_count": null 705 | }, 706 | { 707 | "cell_type": "code", 708 | "metadata": {}, 709 | "source": [ 710 | "# Plot the confusion matrix\n", 711 | "cm = confusion_matrix(jax.numpy.argmax(logits,axis=-1), targets)\n", 712 | "sns.heatmap(cm, annot=True,fmt='d', cmap='YlGnBu')\n", 713 | "plt.ylabel('Prediction',fontsize=12)\n", 714 | "plt.xlabel('Actual',fontsize=12)\n", 715 | "plt.title('Confusion Matrix',fontsize=16)\n", 716 | "plt.show()" 717 | ], 718 | "outputs": [], 719 | "execution_count": null 720 | }, 721 | { 722 | "cell_type": "code", 723 | "metadata": {}, 724 | "source": "print(f\"Accuracy of the model: {(jnp.argmax(logits,axis=-1)==targets).mean()}\")", 725 | "outputs": [], 726 | "execution_count": null 727 | }, 728 | { 729 | "cell_type": "markdown", 730 | "metadata": {}, 731 | "source": [ 732 | "# Task 3 - DVS Gesture " 733 | ] 734 | }, 735 | { 736 | "cell_type": "markdown", 737 | "metadata": {}, 738 | "source": [ 739 | "## Task Description\n", 740 | "\n", 741 | "It is the first gesture recognition system implemented end-to-end on event-based hardware. The dataset comprises of 11 hand gesture categories from 29 subjects under 3 illumination conditions.\n", 742 | "\n", 743 | "Ref : https://ieeexplore.ieee.org/document/8100264\n", 744 | "\n", 745 | "### Excercise\n", 746 | "\n", 747 | "Similar to SHD and SSC, implement inference steps for DVS Gesture data." 748 | ] 749 | }, 750 | { 751 | "cell_type": "markdown", 752 | "metadata": {}, 753 | "source": [ 754 | "### Step 1 : Load configuration" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "metadata": {}, 760 | "source": [ 761 | "# Load configurations\n", 762 | "with initialize(version_base=None, config_path=\"configs\"):\n", 763 | " cfg = compose(config_name=\"base.yaml\",overrides=[\"task=dvs-gesture\"])" 764 | ], 765 | "outputs": [], 766 | "execution_count": null 767 | }, 768 | { 769 | "cell_type": "code", 770 | "metadata": {}, 771 | "source": [ 772 | "# model config:\n", 773 | "print(om.to_yaml(cfg.model))" 774 | ], 775 | "outputs": [], 776 | "execution_count": null 777 | }, 778 | { 779 | "cell_type": "markdown", 780 | "metadata": {}, 781 | "source": [ 782 | "### Step 2 : Visualise Data" 783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "metadata": {}, 788 | "source": [ 789 | "from IPython.display import HTML\n", 790 | "import warnings\n", 791 | "\n", 792 | "warnings.filterwarnings(\"ignore\")\n", 793 | "\n", 794 | "#warnings.filterwarnings( \"ignore\", module = \"matplotlib\\..*\" )\n", 795 | "\n", 796 | "data = tonic.datasets.DVSGesture(cfg.data_dir, train=False)\n", 797 | "events, label = data[0]\n", 798 | "\n", 799 | "transform = tonic.transforms.Compose(\n", 800 | " [\n", 801 | " tonic.transforms.TimeJitter(std=100, clip_negative=False),\n", 802 | " tonic.transforms.ToFrame(\n", 803 | " sensor_size=data.sensor_size,\n", 804 | " time_window=10000,\n", 805 | " ),\n", 806 | " ]\n", 807 | ")\n", 808 | "\n", 809 | "frames = transform(events)\n", 810 | "HTML(tonic.utils.plot_animation((frames* 255).astype(np.uint8)).to_html5_video())" 811 | ], 812 | "outputs": [], 813 | "execution_count": null 814 | }, 815 | { 816 | "cell_type": "markdown", 817 | "metadata": {}, 818 | "source": [ 819 | "### Step 3: Load single inference sample" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "metadata": {}, 825 | "source": [ 826 | "from 
event_ssm.transform import Identity" 827 | ], 828 | "outputs": [], 829 | "execution_count": null 830 | }, 831 | { 832 | "cell_type": "code", 833 | "metadata": {}, 834 | "source": [ 835 | "def create_events_dvs_gesture_classification_dataset(\n", 836 | " cache_dir: Union[str, Path] = DEFAULT_CACHE_DIR_ROOT,\n", 837 | " per_device_eval_batch_size: int = 64,\n", 838 | " world_size: int = 1,\n", 839 | " num_workers: int = 0,\n", 840 | " seed: int = 42,\n", 841 | " pad_unit: int = 2 ** 19,\n", 842 | " downsampling: int=1,\n", 843 | " **kwargs\n", 844 | ") -> Tuple[DataLoader, Data]:\n", 845 | " \"\"\"\n", 846 | " creates a view of the DVS Gesture dataset\n", 847 | "\n", 848 | " :param cache_dir:\t\t(str):\t\twhere to store the dataset\n", 849 | " :param bsz:\t\t\t\t(int):\t\tBatch size.\n", 850 | " :param seed:\t\t\t(int)\t\tSeed for shuffling data.\n", 851 | " \"\"\"\n", 852 | " print(\"[*] Generating DVS Gesture Classification Dataset\")\n", 853 | "\n", 854 | " if seed is not None:\n", 855 | " rng = torch.Generator()\n", 856 | " rng.manual_seed(seed)\n", 857 | " else:\n", 858 | " rng = None\n", 859 | "\n", 860 | " orig_sensor_size = (128, 128, 2)\n", 861 | " new_sensor_size = (128 // downsampling, 128 // downsampling, 2)\n", 862 | " test_transforms = tonic.transforms.Compose([\n", 863 | " tonic.transforms.Downsample(sensor_size=orig_sensor_size, target_size=new_sensor_size[:2]) if downsampling > 1 else Identity(),\n", 864 | " ])\n", 865 | "\n", 866 | " TestData = partial(tonic.datasets.DVSGesture, save_to=cache_dir, train=False)\n", 867 | " test_data = TestData(transform=test_transforms)\n", 868 | "\n", 869 | " # define collate function\n", 870 | " eval_collate_fn = partial(\n", 871 | " event_stream_collate_fn,\n", 872 | " resolution=new_sensor_size[:2],\n", 873 | " pad_unit=pad_unit,\n", 874 | " )\n", 875 | " test_loader = event_stream_dataloader(\n", 876 | " test_data,\n", 877 | " eval_collate_fn=eval_collate_fn,\n", 878 | " eval_batch_size=per_device_eval_batch_size * world_size,\n", 879 | " rng=rng, \n", 880 | " num_workers=num_workers\n", 881 | " )\n", 882 | "\n", 883 | " data = Data(\n", 884 | " n_classes=11, num_embeddings=np.prod(new_sensor_size)\n", 885 | " )\n", 886 | " return test_loader, data" 887 | ], 888 | "outputs": [], 889 | "execution_count": null 890 | }, 891 | { 892 | "cell_type": "code", 893 | "metadata": {}, 894 | "source": [ 895 | "num_devices = jax.local_device_count()\n", 896 | " # Create dataset...\n", 897 | "test_loader, data = create_events_dvs_gesture_classification_dataset(\n", 898 | " cache_dir=cfg.data_dir,\n", 899 | " seed=cfg.seed,\n", 900 | " world_size=num_devices,\n", 901 | " per_device_eval_batch_size = 1,\n", 902 | " pad_unit=cfg.training.pad_unit, \n", 903 | " )" 904 | ], 905 | "outputs": [], 906 | "execution_count": null 907 | }, 908 | { 909 | "cell_type": "code", 910 | "metadata": {}, 911 | "source": [ 912 | "# Load a sample\n", 913 | "batch = next(iter(test_loader))\n", 914 | "inputs, targets, timesteps, lengths = batch" 915 | ], 916 | "outputs": [], 917 | "execution_count": null 918 | }, 919 | { 920 | "cell_type": "markdown", 921 | "metadata": {}, 922 | "source": [ 923 | "### Step 4 : Load model" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "metadata": {}, 929 | "source": [ 930 | "# Set the random key for the task\n", 931 | "init_key = jax.random.PRNGKey(cfg.seed)" 932 | ], 933 | "outputs": [], 934 | "execution_count": null 935 | }, 936 | { 937 | "cell_type": "code", 938 | "metadata": {}, 939 | "source": [ 940 | "print(\"[*] Creating 
model...\")\n", 941 | "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", 942 | "model = BatchClassificationModel(\n", 943 | " ssm=ssm_init_fn,\n", 944 | " num_classes=data.n_classes,\n", 945 | " num_embeddings=data.num_embeddings,\n", 946 | " **cfg.model.ssm,\n", 947 | " )" 948 | ], 949 | "outputs": [], 950 | "execution_count": null 951 | }, 952 | { 953 | "cell_type": "code", 954 | "metadata": {}, 955 | "source": [ 956 | "# visualise model\n", 957 | "print(model.tabulate({\"params\": init_key},\n", 958 | " inputs, timesteps, lengths, False))" 959 | ], 960 | "outputs": [], 961 | "execution_count": null 962 | }, 963 | { 964 | "cell_type": "code", 965 | "metadata": {}, 966 | "source": [ 967 | "# load model parameters from checkpoint\n", 968 | "checkpoint_dir = os.path.abspath(os.path.join(cfg.checkpoint_dir, 'DVS'))\n", 969 | "training_state = checkpoints.restore_checkpoint(checkpoint_dir, target=None)\n", 970 | "params = training_state['params']\n", 971 | "model_state = training_state['model_state']" 972 | ], 973 | "outputs": [], 974 | "execution_count": null 975 | }, 976 | { 977 | "cell_type": "markdown", 978 | "metadata": {}, 979 | "source": [ 980 | "### Step 5 - Model prediction" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "metadata": {}, 986 | "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)", 987 | "outputs": [], 988 | "execution_count": null 989 | }, 990 | { 991 | "cell_type": "code", 992 | "metadata": {}, 993 | "source": [ 994 | "print(f\"Predicted label:{jnp.argmax(logits,axis=-1)}\")\n", 995 | "print(f\"Actual label:{targets}\")" 996 | ], 997 | "outputs": [], 998 | "execution_count": null 999 | }, 1000 | { 1001 | "cell_type": "markdown", 1002 | "metadata": {}, 1003 | "source": [ 1004 | "### Step 6 - Evaluate model on single batch" 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "code", 1009 | "metadata": {}, 1010 | "source": [ 1011 | "num_devices = jax.local_device_count()\n", 1012 | " # Create dataset...\n", 1013 | "test_loader, data = create_events_dvs_gesture_classification_dataset(\n", 1014 | " cache_dir=cfg.data_dir,\n", 1015 | " seed=cfg.seed,\n", 1016 | " world_size=num_devices,\n", 1017 | " per_device_eval_batch_size = cfg.training.per_device_eval_batch_size,\n", 1018 | " pad_unit=cfg.training.pad_unit,\n", 1019 | " #no_time_information = cfg.training.no_time_information\n", 1020 | " )" 1021 | ], 1022 | "outputs": [], 1023 | "execution_count": null 1024 | }, 1025 | { 1026 | "cell_type": "code", 1027 | "metadata": {}, 1028 | "source": [ 1029 | "# Load a batch\n", 1030 | "batch = next(iter(test_loader))\n", 1031 | "inputs, targets, timesteps, lengths = batch" 1032 | ], 1033 | "outputs": [], 1034 | "execution_count": null 1035 | }, 1036 | { 1037 | "cell_type": "code", 1038 | "metadata": { 1039 | "jupyter": { 1040 | "is_executing": true 1041 | } 1042 | }, 1043 | "source": "logits = model.apply({'params': params, **model_state}, inputs, timesteps, lengths, False)", 1044 | "outputs": [], 1045 | "execution_count": null 1046 | }, 1047 | { 1048 | "cell_type": "code", 1049 | "metadata": {}, 1050 | "source": [ 1051 | "# Plot the confusion matrix\n", 1052 | "cm = confusion_matrix(jax.numpy.argmax(logits, axis=-1), targets)\n", 1053 | "sns.heatmap(cm, annot=True, fmt='d', cmap='YlGnBu')\n", 1054 | "plt.ylabel('Prediction', fontsize=12)\n", 1055 | "plt.xlabel('Actual', fontsize=12)\n", 1056 | "plt.title('Confusion Matrix', fontsize=16)\n", 1057 | "plt.show()" 1058 | ], 1059 | "outputs": [], 1060 | "execution_count": null 
1061 | }, 1062 | { 1063 | "cell_type": "code", 1064 | "metadata": {}, 1065 | "source": "print(f\"Accuracy of the model: {(jnp.argmax(logits, axis=-1) == targets).mean()}\")", 1066 | "outputs": [], 1067 | "execution_count": null 1068 | }, 1069 | { 1070 | "metadata": {}, 1071 | "cell_type": "code", 1072 | "source": "", 1073 | "outputs": [], 1074 | "execution_count": null 1075 | } 1076 | ], 1077 | "metadata": { 1078 | "kernelspec": { 1079 | "display_name": "blocksparse", 1080 | "language": "python", 1081 | "name": "python3" 1082 | }, 1083 | "language_info": { 1084 | "codemirror_mode": { 1085 | "name": "ipython", 1086 | "version": 3 1087 | }, 1088 | "file_extension": ".py", 1089 | "mimetype": "text/x-python", 1090 | "name": "python", 1091 | "nbconvert_exporter": "python", 1092 | "pygments_lexer": "ipython3", 1093 | "version": "3.10.4" 1094 | } 1095 | }, 1096 | "nbformat": 4, 1097 | "nbformat_minor": 2 1098 | } 1099 | -------------------------------------------------------------------------------- /tutorial_online_inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "metadata": {}, 5 | "cell_type": "markdown", 6 | "source": [ 7 | "# Online Inference Tutorial\n", 8 | "In this tutorial, we will implement online inference with event-based state-space models.\n", 9 | "Online inference is the process of classifying events as they arrive in real-time.\n", 10 | "For many edge systems, the batch size is 1, and the model has to meet a specific throughput of events per second.\n", 11 | "Here, you will test if your CPU is able to run real-time classification with EventSSM.\n", 12 | "\n", 13 | "The tutorial requires basic familiarity with JAX." 14 | ], 15 | "id": "b99721b9d6b26c10" 16 | }, 17 | { 18 | "metadata": { 19 | "ExecuteTime": { 20 | "end_time": "2024-05-27T09:22:32.658921Z", 21 | "start_time": "2024-05-27T09:22:32.654126Z" 22 | } 23 | }, 24 | "cell_type": "code", 25 | "source": [ 26 | "from hydra import initialize, compose\n", 27 | "from omegaconf import OmegaConf as om\n", 28 | "\n", 29 | "import jax\n", 30 | "import jax.numpy as jnp\n", 31 | "\n", 32 | "from event_ssm.ssm import init_S5SSM\n", 33 | "from event_ssm.seq_model import ClassificationModel" 34 | ], 35 | "id": "bc0a9044321d654d", 36 | "outputs": [], 37 | "execution_count": 24 38 | }, 39 | { 40 | "metadata": {}, 41 | "cell_type": "markdown", 42 | "source": "## Step 1: Load the model", 43 | "id": "d8b261a76014fbc7" 44 | }, 45 | { 46 | "metadata": { 47 | "ExecuteTime": { 48 | "end_time": "2024-05-27T09:22:33.679045Z", 49 | "start_time": "2024-05-27T09:22:33.561733Z" 50 | } 51 | }, 52 | "cell_type": "code", 53 | "source": [ 54 | "# Load configurations\n", 55 | "with initialize(version_base=None, config_path=\"configs\"):\n", 56 | " cfg = compose(config_name=\"base.yaml\", overrides=[\"model=dvs/small\"])" 57 | ], 58 | "id": "7efb7b5428f7472", 59 | "outputs": [], 60 | "execution_count": 25 61 | }, 62 | { 63 | "metadata": { 64 | "ExecuteTime": { 65 | "end_time": "2024-05-27T09:22:33.771341Z", 66 | "start_time": "2024-05-27T09:22:33.766065Z" 67 | } 68 | }, 69 | "cell_type": "code", 70 | "source": [ 71 | "# Print the configuration\n", 72 | "print(om.to_yaml(cfg.model))" 73 | ], 74 | "id": "16eb6e254f8090cd", 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "ssm_init:\n", 81 | " C_init: lecun_normal\n", 82 | " dt_min: 0.001\n", 83 | " dt_max: 0.1\n", 84 | " conj_sym: false\n", 85 | " clip_eigs: true\n", 86 | "ssm:\n", 87 
| " discretization: async\n", 88 | " d_model: 128\n", 89 | " d_ssm: 128\n", 90 | " ssm_block_size: 16\n", 91 | " num_stages: 2\n", 92 | " num_layers_per_stage: 3\n", 93 | " dropout: 0.25\n", 94 | " classification_mode: timepool\n", 95 | " prenorm: true\n", 96 | " batchnorm: false\n", 97 | " bn_momentum: 0.95\n", 98 | " pooling_stride: 16\n", 99 | " pooling_mode: timepool\n", 100 | " state_expansion_factor: 2\n", 101 | "\n" 102 | ] 103 | } 104 | ], 105 | "execution_count": 26 106 | }, 107 | { 108 | "metadata": { 109 | "ExecuteTime": { 110 | "end_time": "2024-05-27T09:22:34.290002Z", 111 | "start_time": "2024-05-27T09:22:34.282856Z" 112 | } 113 | }, 114 | "cell_type": "code", 115 | "source": [ 116 | "# Set the random seed manually for reproducibility.\n", 117 | "key = jax.random.PRNGKey(cfg.seed)\n", 118 | "init_key, data_key = jax.random.split(key)" 119 | ], 120 | "id": "9806959c6627a4d5", 121 | "outputs": [], 122 | "execution_count": 27 123 | }, 124 | { 125 | "metadata": { 126 | "ExecuteTime": { 127 | "end_time": "2024-05-27T09:22:34.904836Z", 128 | "start_time": "2024-05-27T09:22:34.897859Z" 129 | } 130 | }, 131 | "cell_type": "code", 132 | "source": [ 133 | "# Model initialisation in flax\n", 134 | "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", 135 | "\n", 136 | "# number of classes (dummy)\n", 137 | "classes = 10\n", 138 | "\n", 139 | "# number of tokens for a DVS sensor of size 128x128\n", 140 | "num_tokens = 128 * 128 * 2\n", 141 | "model = ClassificationModel(\n", 142 | " ssm=ssm_init_fn,\n", 143 | " num_classes=10,\n", 144 | " num_embeddings=num_tokens,\n", 145 | " **cfg.model.ssm,\n", 146 | " )" 147 | ], 148 | "id": "b936f3fdd1538bfe", 149 | "outputs": [], 150 | "execution_count": 28 151 | }, 152 | { 153 | "metadata": {}, 154 | "cell_type": "markdown", 155 | "source": [ 156 | "EventSSM subsamples sequences in multiple stages to reduce the computational cost.\n", 157 | "Let's investigate the total subsampling" 158 | ], 159 | "id": "accb046df2d07e7" 160 | }, 161 | { 162 | "metadata": { 163 | "ExecuteTime": { 164 | "end_time": "2024-05-27T09:56:14.174709Z", 165 | "start_time": "2024-05-27T09:56:14.161702Z" 166 | } 167 | }, 168 | "cell_type": "code", 169 | "source": [ 170 | "total_subsampling = cfg.model.ssm.pooling_stride ** cfg.model.ssm.num_stages\n", 171 | "print(f\"Total subsampling: {total_subsampling}\")" 172 | ], 173 | "id": "3ed763820fe9f204", 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "Total subsampling: 256\n" 180 | ] 181 | } 182 | ], 183 | "execution_count": 35 184 | }, 185 | { 186 | "metadata": { 187 | "ExecuteTime": { 188 | "end_time": "2024-05-27T09:56:42.653733Z", 189 | "start_time": "2024-05-27T09:56:38.056333Z" 190 | } 191 | }, 192 | "cell_type": "code", 193 | "source": [ 194 | "# initialize model parameters\n", 195 | "x = jnp.zeros(total_subsampling, dtype=jnp.int32)\n", 196 | "t = jnp.ones(total_subsampling)\n", 197 | "variables = model.init(\n", 198 | " {\"params\": init_key},\n", 199 | " x, t, total_subsampling, False\n", 200 | " )" 201 | ], 202 | "id": "e18fbb811f6c46e0", 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)\n", 209 | "SSM: 128 -> 128 -> 128\n", 210 | "SSM: 128 -> 128 -> 128\n", 211 | "SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)\n", 212 | "SSM: 256 -> 256 -> 256\n", 213 | "SSM: 256 -> 256 -> 256\n" 214 | ] 215 | } 216 | ], 217 | "execution_count": 36 
218 | }, 219 | { 220 | "metadata": {}, 221 | "cell_type": "markdown", 222 | "source": [ 223 | "## Step 2: Run the model on random data\n", 224 | "Generate a random list of integer tokens, jit compile the model and classify online." 225 | ], 226 | "id": "8ed847f8098b7f53" 227 | }, 228 | { 229 | "metadata": { 230 | "ExecuteTime": { 231 | "end_time": "2024-05-27T10:14:54.375839Z", 232 | "start_time": "2024-05-27T10:14:54.360101Z" 233 | } 234 | }, 235 | "cell_type": "code", 236 | "source": [ 237 | "# Generate random data\n", 238 | "sequence_length = 2 ** 18\n", 239 | "tokens = jax.random.randint(data_key, shape=(sequence_length,), minval=0, maxval=num_tokens)\n", 240 | "timesteps = jnp.ones(sequence_length)\n", 241 | "print(\"Sequence length:\", sequence_length)" 242 | ], 243 | "id": "9b32e55bfaf178e9", 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "Sequence length: 262144\n" 250 | ] 251 | } 252 | ], 253 | "execution_count": 63 254 | }, 255 | { 256 | "metadata": { 257 | "ExecuteTime": { 258 | "end_time": "2024-05-27T10:15:07.170346Z", 259 | "start_time": "2024-05-27T10:14:55.732901Z" 260 | } 261 | }, 262 | "cell_type": "code", 263 | "source": [ 264 | "# jit compile the model\n", 265 | "from functools import partial\n", 266 | "jit_apply = jax.jit(partial(model.apply, length=total_subsampling, train=False))\n", 267 | "jit_apply(variables, x[:total_subsampling], t[:total_subsampling])" 268 | ], 269 | "id": "8f49cd496d6ef30d", 270 | "outputs": [ 271 | { 272 | "name": "stdout", 273 | "output_type": "stream", 274 | "text": [ 275 | "SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)\n", 276 | "SSM: 128 -> 128 -> 128\n", 277 | "SSM: 128 -> 128 -> 128\n", 278 | "SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)\n", 279 | "SSM: 256 -> 256 -> 256\n", 280 | "SSM: 256 -> 256 -> 256\n" 281 | ] 282 | }, 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "Array([-0.12317943, -0.17902763, -0.26315966, 0.5992651 , 0.7048361 ,\n", 287 | " 1.2036127 , 0.00121723, 0.41398254, 0.26262668, 0.18357195], dtype=float32)" 288 | ] 289 | }, 290 | "execution_count": 64, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "execution_count": 64 296 | }, 297 | { 298 | "metadata": { 299 | "ExecuteTime": { 300 | "end_time": "2024-05-27T10:15:11.763566Z", 301 | "start_time": "2024-05-27T10:15:08.166525Z" 302 | } 303 | }, 304 | "cell_type": "code", 305 | "source": [ 306 | "# loop through the model\n", 307 | "from tqdm import tqdm\n", 308 | "from time import time\n", 309 | "print(f\"Looping through {sequence_length} events with total_subsampling={total_subsampling} --> {sequence_length // total_subsampling} iterations\")\n", 310 | "start = time()\n", 311 | "for i in tqdm(range(0, sequence_length, total_subsampling)):\n", 312 | " x = tokens[i:i + total_subsampling]\n", 313 | " t = timesteps[i:i + total_subsampling]\n", 314 | " logits = jit_apply(variables, x, t).block_until_ready()\n", 315 | "end = time()\n", 316 | "print(f\"Time taken: {end - start:.2f}s\")\n", 317 | "print(f\"Events per second: {sequence_length / (end - start):.2f}\")" 318 | ], 319 | "id": "55a885c77a44e8eb", 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "Looping through 262144 events with total_subsampling=256 --> 1024 iterations\n" 326 | ] 327 | }, 328 | { 329 | "name": "stderr", 330 | "output_type": "stream", 331 | "text": [ 332 | "100%|██████████| 1024/1024 [00:03<00:00, 285.19it/s]" 333 | ] 
334 | }, 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "Time taken: 3.59s\n", 340 | "Events per second: 72962.94\n" 341 | ] 342 | }, 343 | { 344 | "name": "stderr", 345 | "output_type": "stream", 346 | "text": [ 347 | "\n" 348 | ] 349 | } 350 | ], 351 | "execution_count": 65 352 | }, 353 | { 354 | "metadata": {}, 355 | "cell_type": "markdown", 356 | "source": [ 357 | "## Step 3: Optimize the inference speed\n", 358 | "We suggest to use [jax.lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of the for loop to further speed up the inference." 359 | ], 360 | "id": "541f0afde67081f8" 361 | }, 362 | { 363 | "metadata": { 364 | "ExecuteTime": { 365 | "end_time": "2024-05-27T10:15:27.619686Z", 366 | "start_time": "2024-05-27T10:15:14.529552Z" 367 | } 368 | }, 369 | "cell_type": "code", 370 | "source": [ 371 | "def step(carry, inputs):\n", 372 | " x, t = inputs\n", 373 | " logits = model.apply(variables, x, t, total_subsampling, False)\n", 374 | " return None, logits\n", 375 | "tokens = tokens.reshape(-1, total_subsampling)\n", 376 | "timesteps = timesteps.reshape(-1, total_subsampling)\n", 377 | "\n", 378 | "# run the scan: first jit-compiles and then iterates\n", 379 | "logits = jax.lax.scan(step, init=None, xs=(tokens, timesteps))" 380 | ], 381 | "id": "1318e7467cbb3b3f", 382 | "outputs": [ 383 | { 384 | "name": "stdout", 385 | "output_type": "stream", 386 | "text": [ 387 | "SSM: 128 -> 128 -> 128 (stride 16 with pooling mode timepool)\n", 388 | "SSM: 128 -> 128 -> 128\n", 389 | "SSM: 128 -> 128 -> 128\n", 390 | "SSM: 128 -> 256 -> 256 (stride 16 with pooling mode timepool)\n", 391 | "SSM: 256 -> 256 -> 256\n", 392 | "SSM: 256 -> 256 -> 256\n" 393 | ] 394 | } 395 | ], 396 | "execution_count": 66 397 | }, 398 | { 399 | "metadata": { 400 | "ExecuteTime": { 401 | "end_time": "2024-05-27T10:15:49.444621Z", 402 | "start_time": "2024-05-27T10:15:46.788818Z" 403 | } 404 | }, 405 | "cell_type": "code", 406 | "source": [ 407 | "# measure run-time\n", 408 | "start = time()\n", 409 | "_, logits = jax.block_until_ready(jax.lax.scan(step, init=None, xs=(tokens, timesteps)))\n", 410 | "end = time()\n", 411 | "print(f\"Time taken: {end - start:.2f}s\")\n", 412 | "print(f\"Events per second: {sequence_length / (end - start):.2f}\")" 413 | ], 414 | "id": "aa170aadad84036d", 415 | "outputs": [ 416 | { 417 | "name": "stdout", 418 | "output_type": "stream", 419 | "text": [ 420 | "Time taken: 2.65s\n", 421 | "Events per second: 99018.86\n" 422 | ] 423 | } 424 | ], 425 | "execution_count": 68 426 | }, 427 | { 428 | "metadata": { 429 | "ExecuteTime": { 430 | "end_time": "2024-05-27T10:15:53.224299Z", 431 | "start_time": "2024-05-27T10:15:53.220810Z" 432 | } 433 | }, 434 | "cell_type": "code", 435 | "source": "logits.shape", 436 | "id": "718dffb170c2df1c", 437 | "outputs": [ 438 | { 439 | "data": { 440 | "text/plain": [ 441 | "(1024, 10)" 442 | ] 443 | }, 444 | "execution_count": 69, 445 | "metadata": {}, 446 | "output_type": "execute_result" 447 | } 448 | ], 449 | "execution_count": 69 450 | }, 451 | { 452 | "metadata": {}, 453 | "cell_type": "markdown", 454 | "source": [ 455 | "## Step 4: Run inference on the DVS128 Gestures dataset\n", 456 | "Follow the steps in the `tutorial_inference.ipynb` to run inference on the DVS128 Gestures dataset with a pretrained model.\n", 457 | "Plot the confidence of the model in the correct class over time" 458 | ], 459 | "id": "bcaba7dc4697605d" 460 | }, 461 | { 462 | "metadata": {}, 463 | "cell_type": 
"code", 464 | "outputs": [], 465 | "execution_count": null, 466 | "source": "", 467 | "id": "d9110111c449d185" 468 | } 469 | ], 470 | "metadata": { 471 | "kernelspec": { 472 | "display_name": "Python 3", 473 | "language": "python", 474 | "name": "python3" 475 | }, 476 | "language_info": { 477 | "codemirror_mode": { 478 | "name": "ipython", 479 | "version": 2 480 | }, 481 | "file_extension": ".py", 482 | "mimetype": "text/x-python", 483 | "name": "python", 484 | "nbconvert_exporter": "python", 485 | "pygments_lexer": "ipython2", 486 | "version": "2.7.6" 487 | } 488 | }, 489 | "nbformat": 4, 490 | "nbformat_minor": 5 491 | } 492 | -------------------------------------------------------------------------------- /tutorial_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "metadata": {}, 5 | "cell_type": "markdown", 6 | "source": [ 7 | "# Tutorial: Training a model\n", 8 | "In this tutorial, we will train an event-based state-space model on a reduced version of the [Spiking Heidelberg Digits](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/) dataset.\n", 9 | "For training on larger datasets or multiple GPUs, we recommend using the training script `run_training.py` instead.\n", 10 | "\n", 11 | "## Setup\n", 12 | "\n", 13 | "Install and load the important modules and configuration. To install required packages, please do \n", 14 | "```\n", 15 | "pip3 install requirements.txt\n", 16 | "```\n", 17 | "\n", 18 | "Directories for loading datasets, model checkpoints and saving results are defined in the configuration file `system/local.yaml`.\n", 19 | "Please set your directories accordingly." 20 | ], 21 | "id": "4d02d51dcadfcfb" 22 | }, 23 | { 24 | "metadata": {}, 25 | "cell_type": "markdown", 26 | "source": [ 27 | "## Data loading\n", 28 | "The SHD dataset contains 20 classes, digits from 0 to 9 in both German and English. \n", 29 | "We will use a reduced version of the dataset containing only two digits to train the model to non-trivial performance in reasonable time even on CPUs.\n", 30 | "\n", 31 | "[Download the training and test dataset](https://zenkelab.org/datasets/) and unpack the archives to `./data/`." 
32 | ],
33 | "id": "df7ee68ff3e429ed"
34 | },
35 | {
36 | "metadata": {},
37 | "cell_type": "code",
38 | "source": [
39 | "from torch.utils.data import Dataset, DataLoader, random_split\n",
40 | "import h5py\n",
41 | "import numpy as np\n",
42 | "\n",
43 | "class SpikingHeidelbergDigits(Dataset):\n",
44 | "    def __init__(self, path_to_file):\n",
45 | "        self.num_classes = 2\n",
46 | "        self.num_channels = 700\n",
47 | "        self.path_to_file = path_to_file\n",
48 | "        \n",
49 | "        # load the dataset\n",
50 | "        with h5py.File(path_to_file, 'r') as f:\n",
51 | "            self.channels = f['spikes']['units'][:]\n",
52 | "            self.timesteps = f['spikes']['times'][:]\n",
53 | "            self.labels = f['labels'][:]\n",
54 | "        \n",
55 | "        # filter the dataset to contain only two classes\n",
56 | "        mask = (self.labels == 0) | (self.labels == 1)\n",
57 | "        self.channels = self.channels[mask]\n",
58 | "        self.timesteps = self.timesteps[mask]\n",
59 | "        self.labels = self.labels[mask]\n",
60 | "        \n",
61 | "    def __len__(self):\n",
62 | "        return len(self.labels)\n",
63 | "    \n",
64 | "    def __getitem__(self, idx):\n",
65 | "        # create tonic-like structured arrays\n",
66 | "        dtype = np.dtype([(\"t\", int), (\"x\", int), (\"p\", int)])\n",
67 | "        struct_arr = np.empty_like(self.channels[idx], dtype=dtype)\n",
68 | "        \n",
69 | "        # convert timesteps from seconds to microseconds (the collate function later converts them to milliseconds)\n",
70 | "        timesteps = self.timesteps[idx] * 1e6\n",
71 | "        \n",
72 | "        struct_arr['t'] = timesteps\n",
73 | "        struct_arr['x'] = self.channels[idx]\n",
74 | "        struct_arr['p'] = 1\n",
75 | "        \n",
76 | "        # one-hot encoding of labels (required for CutMix augmentation)\n",
77 | "        label = np.eye(self.num_classes)[self.labels[idx]].astype(np.int32)\n",
78 | "        \n",
79 | "        return struct_arr, label"
80 | ],
81 | "id": "f9883d23c86e5bcd",
82 | "outputs": [],
83 | "execution_count": null
84 | },
85 | {
86 | "metadata": {},
87 | "cell_type": "code",
88 | "source": [
89 | "# Load the training and test dataset\n",
90 | "train_dataset = SpikingHeidelbergDigits('data/shd_train.h5')\n",
91 | "test_dataset = SpikingHeidelbergDigits('data/shd_test.h5')"
92 | ],
93 | "id": "3be0429979f96a3f",
94 | "outputs": [],
95 | "execution_count": null
96 | },
97 | {
98 | "metadata": {},
99 | "cell_type": "markdown",
100 | "source": "Check the lengths of the datasets to verify that the data loading was successful.",
101 | "id": "cf72529578541b9"
102 | },
103 | {
104 | "metadata": {},
105 | "cell_type": "code",
106 | "source": [
107 | "print(f\"Number of training samples: {len(train_dataset)}\")\n",
108 | "print(f\"Number of test samples: {len(test_dataset)}\")"
109 | ],
110 | "id": "ec059aefa5d3408",
111 | "outputs": [],
112 | "execution_count": null
113 | },
114 | {
115 | "metadata": {},
116 | "cell_type": "markdown",
117 | "source": "Now, create a validation set by randomly splitting the training dataset, and create data loaders for the training, validation, and test datasets.",
118 | "id": "ee746256534411df"
119 | },
120 | {
121 | "metadata": {},
122 | "cell_type": "code",
123 | "source": [
124 | "# Split the training dataset into training and validation\n",
125 | "train_dataset, val_dataset = random_split(train_dataset, [int(0.8*len(train_dataset)), len(train_dataset) - int(0.8*len(train_dataset))])\n",
126 | "\n",
127 | "# Create data loaders\n",
128 | "from event_ssm.dataloading import event_stream_collate_fn\n",
129 | "from functools import partial\n",
130 | "\n",
131 | "collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=8192)\n",
132 | "train_loader = DataLoader(train_dataset, 
batch_size=32, shuffle=True, drop_last=True, collate_fn=collate_fn)\n", 133 | "val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=True, collate_fn=collate_fn)\n", 134 | "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)" 135 | ], 136 | "id": "1ab24a1d63c4c194", 137 | "outputs": [], 138 | "execution_count": null 139 | }, 140 | { 141 | "metadata": {}, 142 | "cell_type": "markdown", 143 | "source": [ 144 | "## Model definition\n", 145 | "We use the [hydra](https://hydra.cc/docs/intro/) package for efficient configuration management. Define the model configuration in a config file in the `configs` directory." 146 | ], 147 | "id": "acc88e3270fda10b" 148 | }, 149 | { 150 | "metadata": {}, 151 | "cell_type": "code", 152 | "source": [ 153 | "from hydra import compose, initialize\n", 154 | "from omegaconf import OmegaConf, open_dict\n", 155 | "\n", 156 | "with initialize(version_base=None, config_path=\"configs\", job_name=\"training tutorial\"):\n", 157 | " cfg = compose(config_name=\"base\", overrides=[\"task=tutorial\"])\n", 158 | "\n", 159 | "with open_dict(cfg): \n", 160 | " # optax updates the schedule every iteration and not every epoch\n", 161 | " cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps\n", 162 | " cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps\n", 163 | " \n", 164 | " # scale learning rate by batch size\n", 165 | " cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * cfg.optimizer.accumulation_steps\n", 166 | "\n", 167 | "print(OmegaConf.to_yaml(cfg))" 168 | ], 169 | "id": "810edc0798ad7622", 170 | "outputs": [], 171 | "execution_count": null 172 | }, 173 | { 174 | "metadata": {}, 175 | "cell_type": "markdown", 176 | "source": "Now, create the model using the configuration defined above.", 177 | "id": "2ca62d33ebabfdb2" 178 | }, 179 | { 180 | "metadata": {}, 181 | "cell_type": "code", 182 | "source": [ 183 | "from event_ssm.ssm import init_S5SSM\n", 184 | "from event_ssm.seq_model import BatchClassificationModel\n", 185 | "\n", 186 | "ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)\n", 187 | "model = BatchClassificationModel(\n", 188 | " ssm=ssm_init_fn,\n", 189 | " num_classes=test_dataset.num_classes,\n", 190 | " num_embeddings=test_dataset.num_channels,\n", 191 | " **cfg.model.ssm,\n", 192 | ")" 193 | ], 194 | "id": "83e2062ea8b4fe02", 195 | "outputs": [], 196 | "execution_count": null 197 | }, 198 | { 199 | "metadata": {}, 200 | "cell_type": "markdown", 201 | "source": [ 202 | "\n", 203 | "Initialize the training state by feeding a dummy input" 204 | ], 205 | "id": "c855737447f70896" 206 | }, 207 | { 208 | "metadata": {}, 209 | "cell_type": "code", 210 | "source": [ 211 | "import jax\n", 212 | "from event_ssm.train_utils import init_model_state\n", 213 | "\n", 214 | "# pick the first batch from the training loader\n", 215 | "batch = next(iter(train_loader))\n", 216 | "inputs, targets, timesteps, lengths = batch\n", 217 | "\n", 218 | "# initialize the training state\n", 219 | "key = jax.random.PRNGKey(cfg.seed)\n", 220 | "state = init_model_state(key, model, inputs, timesteps, lengths, cfg.optimizer)" 221 | ], 222 | "id": "d4def1c65952a8ba", 223 | "outputs": [], 224 | "execution_count": null 225 | }, 226 | { 227 | "metadata": {}, 228 | "cell_type": "markdown", 229 | "source": [ 230 | "## Inspect the model\n", 231 | "The model parameters are accessible as 
part of the training state. \n", 232 | "We will look into the spectrum of the recurrent operator here.\n", 233 | "The model was initialized with a single stage of blocks." 234 | ], 235 | "id": "424bce6010abb8f1" 236 | }, 237 | { 238 | "metadata": {}, 239 | "cell_type": "code", 240 | "source": [ 241 | "def get_spectrum(state):\n", 242 | " params = state.params['encoder']['stages_0']\n", 243 | " lambda_bar = []\n", 244 | " time_scales = []\n", 245 | " for name, sequence_layer in params.items():\n", 246 | " # read lambda parameters\n", 247 | " Lambda_im = sequence_layer['S5SSM_0']['Lambda_im']\n", 248 | " Lambda_re = sequence_layer['S5SSM_0']['Lambda_re']\n", 249 | " \n", 250 | " # read and compute delta and Lambda\n", 251 | " delta = np.exp(sequence_layer['S5SSM_0']['log_step'][:, 0])\n", 252 | " Lambda = Lambda_re + 1j * Lambda_im\n", 253 | " \n", 254 | " # compute lambda_bar and time scales\n", 255 | " lambda_bar.append(np.exp(Lambda * delta))\n", 256 | " time_scales.append(1 / np.abs(Lambda) / delta)\n", 257 | " return lambda_bar, time_scales\n", 258 | "spectrum, time_scales = get_spectrum(state)" 259 | ], 260 | "id": "5a1602ed4a962265", 261 | "outputs": [], 262 | "execution_count": null 263 | }, 264 | { 265 | "metadata": {}, 266 | "cell_type": "markdown", 267 | "source": "Plot the spectrum of the recurrent operator and the corresponding time scales upon initialization.", 268 | "id": "2d1377b48ebf3728" 269 | }, 270 | { 271 | "metadata": {}, 272 | "cell_type": "code", 273 | "source": [ 274 | "import matplotlib.pyplot as plt\n", 275 | "\n", 276 | "def plot_spectrum(spectrum):\n", 277 | " fig, axes = plt.subplots(1, 6, figsize=(len(spectrum) * 4, 4))\n", 278 | " # draw the unit circle\n", 279 | " theta = np.linspace(0, 2 * np.pi, 100) # 100 points from 0 to 2*pi\n", 280 | " x = np.cos(theta)\n", 281 | " y = np.sin(theta)\n", 282 | " \n", 283 | " # plot the spectrum\n", 284 | " for i, (ax, layer) in enumerate(zip(axes, spectrum)):\n", 285 | " ax.plot(x, y, 'r', linewidth=1)\n", 286 | " ax.scatter(np.real(layer), np.imag(layer), marker='o', alpha=0.8)\n", 287 | " \n", 288 | " # format axis\n", 289 | " ax.set_title(f'Layer {i}')\n", 290 | " ax.set_aspect('equal', adjustable='box')\n", 291 | " ax.set_xlim(-1.1, 1.1)\n", 292 | " ax.set_ylim(-1.1, 1.1)\n", 293 | " \n", 294 | " plt.tight_layout()\n", 295 | " plt.show()\n", 296 | " \n", 297 | "plot_spectrum(spectrum)" 298 | ], 299 | "id": "9ff987be826d7314", 300 | "outputs": [], 301 | "execution_count": null 302 | }, 303 | { 304 | "metadata": {}, 305 | "cell_type": "code", 306 | "source": [ 307 | "def plot_time_scales(time_scales):\n", 308 | " log_scales = np.log2(np.stack(time_scales).flatten())\n", 309 | " fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", 310 | " ax.hist(log_scales)\n", 311 | " \n", 312 | " # format axis\n", 313 | " max_scale = np.max(np.ceil(log_scales))\n", 314 | " min_scale = np.min(np.floor(log_scales))\n", 315 | " ax.set_xlim((min_scale, max_scale))\n", 316 | " xticks = np.arange(1 + max_scale - min_scale) + min_scale\n", 317 | " ax.set_xticks(xticks, (2 ** xticks).astype(np.int32))\n", 318 | " ax.set_title('Distribution of time scales')\n", 319 | " ax.set_xlabel('Time scale')\n", 320 | " ax.set_ylabel('Count')\n", 321 | " plt.show()\n", 322 | " \n", 323 | "plot_time_scales(time_scales)" 324 | ], 325 | "id": "c1779ae6f5f72b44", 326 | "outputs": [], 327 | "execution_count": null 328 | }, 329 | { 330 | "metadata": {}, 331 | "cell_type": "markdown", 332 | "source": [ 333 | "## Train the model\n", 334 | "For training, we 
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Train the model\n",
    "For training, we implemented a trainer module that makes training as easy as possible. It hides the boilerplate training code from the user and provides a simple interface: it loops over the data loaders, computes the loss, and updates the model parameters. To do so, the loop calls `training_step` and `evaluation_step` functions on the model. These are implemented already and can be used directly here."
   ],
   "id": "b4970c69f459df1d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from event_ssm.train_utils import training_step, evaluation_step\n",
    "from event_ssm.trainer import TrainerModule\n",
    "\n",
    "# just-in-time compile the training and evaluation functions\n",
    "train_step = jax.jit(training_step)\n",
    "eval_step = jax.jit(evaluation_step)\n",
    "\n",
    "# initialize the trainer module\n",
    "num_devices = 1\n",
    "trainer = TrainerModule(\n",
    "    train_state=state,\n",
    "    training_step_fn=train_step,\n",
    "    evaluation_step_fn=eval_step,\n",
    "    world_size=num_devices,\n",
    "    config=cfg,\n",
    ")"
   ],
   "id": "61ab72052c47f47c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "We are now ready to start the training loop.\n",
    "\n",
    "**Note:** JAX compiles your program just-in-time (JIT) to optimize performance. This means that the first iteration of the training loop will be slower than the following ones."
   ],
   "id": "d66a413bc0ac7d2b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# generate a random key for dropout\n",
    "key, dropout_key = jax.random.split(key)\n",
    "\n",
    "# train the model\n",
    "trainer.train_model(\n",
    "    train_loader=train_loader,\n",
    "    val_loader=val_loader,\n",
    "    test_loader=test_loader,\n",
    "    dropout_key=dropout_key\n",
    ")"
   ],
   "id": "9d5ab8aa623db697",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Inspect the trained model\n",
    "We now have a toy model trained on the SHD dataset.\n",
    "Let's look at the spectrum of the recurrent operator after training."
   ],
   "id": "a929b74f8ce235e5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "spectrum, time_scales = get_spectrum(trainer.train_state)\n",
    "plot_spectrum(spectrum)\n",
    "plot_time_scales(time_scales)"
   ],
   "id": "3281a08743303429",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Assignment\n",
    "The function `apply_ssm` in `event_ssm/ssm.py` implements the recurrent operator with an associative scan. On highly parallel GPUs, this can speed up training on very long sequences.\n",
    "On CPUs, however, the overhead of the scan operation can slow down training.\n",
    "Your task is to implement a CPU-friendly version of the recurrent operator in `event_ssm/ssm.py` and compare the training time with the original implementation.\n",
    "We suggest implementing a step-by-step recurrence with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of the currently used [`jax.lax.associative_scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html). A minimal sketch of such a recurrence is shown below."
   ],
   "id": "28ad17b8c7b61230"
  },
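  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The following sketch illustrates one way to approach the assignment. It assumes a diagonal state-space model with per-step decay factors; the function name `apply_ssm_sequential` and the shapes are illustrative and do not match the exact signatures in `event_ssm/ssm.py`, so treat it as a starting point rather than a drop-in replacement. Timing it against the associative-scan version on a CPU backend is the comparison the assignment asks for.",
   "id": "5e6f7a8b9c0d1e2f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def apply_ssm_sequential(lambda_bar, B_bar, C_tilde, input_sequence):\n",
    "    # Sequential recurrence x_k = lambda_bar_k * x_{k-1} + B_bar @ u_k for a\n",
    "    # diagonal SSM. Assumed shapes (for illustration only):\n",
    "    #   lambda_bar:     (L, P) complex per-step decay factors\n",
    "    #   B_bar:          (P, H) complex input matrix\n",
    "    #   C_tilde:        (H, P) complex output matrix\n",
    "    #   input_sequence: (L, H) real inputs\n",
    "\n",
    "    # precompute the driving term B_bar @ u_k for every step\n",
    "    Bu = jax.vmap(lambda u: B_bar @ u)(input_sequence)\n",
    "\n",
    "    def step(x, scan_input):\n",
    "        lam, Bu_k = scan_input\n",
    "        x = lam * x + Bu_k  # one recurrence step\n",
    "        return x, x         # carry the state and emit it as the output\n",
    "\n",
    "    x0 = jnp.zeros(lambda_bar.shape[1], dtype=lambda_bar.dtype)\n",
    "    _, xs = jax.lax.scan(step, x0, (lambda_bar, Bu))\n",
    "\n",
    "    # project the hidden states to real-valued outputs\n",
    "    return jax.vmap(lambda x: (C_tilde @ x).real)(xs)"
   ],
   "id": "6f7a8b9c0d1e2f3a",
   "outputs": [],
   "execution_count": null
  }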
\n", 421 | "Your task is to implement a CPU-friendly version of the recurrent operator in `event_ssm/ssm.py` and compare the training time with the original implementation.\n", 422 | "We suggest to implement a step-by-step recurrence with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) instead of the currenlty used [`jax.lax.associative_scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html) for this purpose." 423 | ], 424 | "id": "28ad17b8c7b61230" 425 | } 426 | ], 427 | "metadata": { 428 | "kernelspec": { 429 | "display_name": "Python 3", 430 | "language": "python", 431 | "name": "python3" 432 | }, 433 | "language_info": { 434 | "codemirror_mode": { 435 | "name": "ipython", 436 | "version": 2 437 | }, 438 | "file_extension": ".py", 439 | "mimetype": "text/x-python", 440 | "name": "python", 441 | "nbconvert_exporter": "python", 442 | "pygments_lexer": "ipython2", 443 | "version": "2.7.6" 444 | } 445 | }, 446 | "nbformat": 4, 447 | "nbformat_minor": 5 448 | } 449 | --------------------------------------------------------------------------------