├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── common.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar10_dvs.py
│   │   ├── common.py
│   │   ├── dvs128_gesture.py
│   │   └── n_mnist.py
│   ├── encoders.py
│   ├── losses.py
│   ├── networks.py
│   ├── nodes.py
│   ├── operations.py
│   └── train.py
├── train_cifar10.sh
├── train_dvsc10.sh
├── train_dvsg.sh
└── train_nmnist.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | .idea
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ 142 | 143 | ### VisualStudioCode template 144 | .vscode/* 145 | !.vscode/settings.json 146 | !.vscode/tasks.json 147 | !.vscode/launch.json 148 | !.vscode/extensions.json 149 | *.code-workspace 150 | 151 | # Local History for Visual Studio Code 152 | .history/ 153 | 154 | ### JetBrains template 155 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 156 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 157 | 158 | # User-specific stuff 159 | .idea/**/workspace.xml 160 | .idea/**/tasks.xml 161 | .idea/**/usage.statistics.xml 162 | .idea/**/dictionaries 163 | .idea/**/shelf 164 | 165 | # Generated files 166 | .idea/**/contentModel.xml 167 | 168 | # Sensitive or high-churn files 169 | .idea/**/dataSources/ 170 | .idea/**/dataSources.ids 171 | .idea/**/dataSources.local.xml 172 | .idea/**/sqlDataSources.xml 173 | .idea/**/dynamic.xml 174 | .idea/**/uiDesigner.xml 175 | .idea/**/dbnavigator.xml 176 | 177 | # Gradle 178 | .idea/**/gradle.xml 179 | .idea/**/libraries 180 | 181 | # Gradle and Maven with auto-import 182 | # When using Gradle or Maven with auto-import, you should exclude module files, 183 | # since they will be recreated, and may cause churn. Uncomment if using 184 | # auto-import. 
185 | # .idea/artifacts
186 | # .idea/compiler.xml
187 | # .idea/jarRepositories.xml
188 | # .idea/modules.xml
189 | # .idea/*.iml
190 | # .idea/modules
191 | # *.iml
192 | # *.ipr
193 |
194 | # CMake
195 | cmake-build-*/
196 |
197 | # Mongo Explorer plugin
198 | .idea/**/mongoSettings.xml
199 |
200 | # File-based project format
201 | *.iws
202 |
203 | # IntelliJ
204 | out/
205 |
206 | # mpeltonen/sbt-idea plugin
207 | .idea_modules/
208 |
209 | # JIRA plugin
210 | atlassian-ide-plugin.xml
211 |
212 | # Cursive Clojure plugin
213 | .idea/replstate.xml
214 |
215 | # Crashlytics plugin (for Android Studio and IntelliJ)
216 | com_crashlytics_export_strings.xml
217 | crashlytics.properties
218 | crashlytics-build.properties
219 | fabric.properties
220 |
221 | # Editor-based Rest Client
222 | .idea/httpRequests
223 |
224 | # Android studio 3.1+ serialized cache file
225 | .idea/caches/build_file_checksums.ser
226 |
227 | .idea/.gitignore
228 | .idea/darts_snn.iml
229 | .idea/deployment.xml
230 | .idea/inspectionProfiles/
231 | .idea/misc.xml
232 | .idea/modules.xml
233 | .idea/vcs.xml
234 | .idea/other.xml
235 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 FloyedShen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Backpropagation with Biologically Plausible Spatio-Temporal Adjustment For Training Deep Spiking Neural Networks
2 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6489856.svg)](https://doi.org/10.5281/zenodo.6489856)
3 |
4 | This repository contains code from our paper [**Backpropagation with Biologically Plausible Spatio-Temporal Adjustment For Training Deep Spiking Neural Networks**](https://www.cell.com/patterns/fulltext/S2666-3899(22)00119-2), published in Patterns (Cell Press). If you use our code or refer to this project, please cite this paper.
5 |
6 | ## Requirements
7 |
8 | * numpy
9 | * scipy
10 | * pytorch >= 1.7.0
11 | * torchvision
12 |
13 |
14 | ## Data preparation
15 |
16 | First modify ```DATA_DIR='path/to/datasets'``` in ```code/datasets/__init__.py``` so that it points to the root directory of your datasets.
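For reference, this is the only change needed; the path below is a placeholder for your own dataset root:

```python
# code/datasets/__init__.py
DATA_DIR = '/path/to/datasets'  # root directory that contains the DVS/ folder
```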
17 | Neuromorphic datasets (**NMNIST**, **DVS-Gesture** and **DVS-CIFAR10**) need to be downloaded manually and placed under ```/path/to/datasets/DVS/```:
18 |
19 |
20 | ```
21 | /path/to/datasets/
22 |     DVS/
23 |         DVS_Cifar10/
24 |         DVS_Gesture/
25 |         NMNIST/
26 | ```
27 |
28 | ## Train
29 |
30 | Run the training script corresponding to each dataset.
31 |
32 | For example, to train and validate the proposed method on the DVS-Gesture dataset:
33 |
34 | ```bash
35 | bash ./train_dvsg.sh
36 | ```
37 |
--------------------------------------------------------------------------------
/code/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | import torch.utils
8 | import torch.nn.functional as F
9 | import torchvision.datasets as dset
10 | from torchvision import transforms
11 | import torch.backends.cudnn as cudnn
12 |
13 | thresh = 0.5  # neuronal threshold
14 | lens = 0.5  # hyper-parameter of the approximate (surrogate) function
15 |
16 |
17 | class ActFun(torch.autograd.Function):
18 |
19 |     @staticmethod
20 |     def forward(ctx, inputs):
21 |         ctx.save_for_backward(inputs)
22 |         return inputs.gt(thresh).float()
23 |
24 |     @staticmethod
25 |     def backward(ctx, grad_output):
26 |         inputs, = ctx.saved_tensors
27 |         grad_inputs = grad_output.clone()
28 |         temp = abs(inputs - thresh) < lens
29 |         return grad_inputs * temp.float()
30 |         # grad = torch.exp((thresh - inputs)) / ((torch.exp((thresh - inputs)) + 1) ** 2)
31 |         # return grad * grad_output
32 |
33 |
34 | class AverageMeter(object):
35 |     def __init__(self):
36 |         self.reset()
37 |
38 |     def reset(self):
39 |         self.avg = 0
40 |         self.sum = 0
41 |         self.cnt = 0
42 |
43 |     def update(self, val, n=1):
44 |         self.sum += val * n
45 |         self.cnt += n
46 |         self.avg = self.sum / self.cnt
47 |
48 |
49 | def accuracy(output, target, topk=(1,)):
50 |     """Compute the top-k accuracy for the specified values of k.
51 |
52 |     """
53 |     maxk = max(topk)
54 |     batch_size = target.size(0)
55 |
56 |     # Return the k largest elements of the given input tensor
57 |     # along a given dimension -> N * k
58 |     _, pred = output.topk(maxk, 1, True, True)
59 |     pred = pred.t()
60 |
61 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
62 |
63 |     res = []
64 |     for k in topk:
65 |         correct_k = correct[:k].reshape(-1).float().sum(0)
66 |         res.append(correct_k.mul_(100.0 / batch_size))
67 |     return res
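# Usage sketch for the surrogate gradient above: the forward pass is a hard
# threshold, while the backward pass only lets gradients through the rectangular
# window |u - thresh| < lens around the threshold (tensor values are illustrative):
#
#   u = torch.tensor([0.2, 0.6, 1.2], requires_grad=True)  # membrane potentials
#   spikes = ActFun.apply(u)   # Heaviside step -> tensor([0., 1., 1.])
#   spikes.sum().backward()    # rectangular surrogate gradient
#   u.grad                     # tensor([1., 1., 0.]) -- 1.2 falls outside the window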
--------------------------------------------------------------------------------
/code/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['get_dvsc10_data', 'get_dvsg_data', 'get_nmnist_data']
2 |
3 | from .cifar10_dvs import CIFAR10DVS
4 | from .dvs128_gesture import DVS128Gesture
5 | from .n_mnist import NMNIST
6 |
7 | import os
8 | import sys
9 | import numpy as np
10 | from PIL import Image
11 | import torch
12 | import torch.utils
13 | from torchvision import transforms
14 |
15 | '''
16 | https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/datasets
17 | '''
18 |
19 | DATA_DIR = '/data/datasets'
20 |
21 |
22 | def get_dvsg_data(batch_size, step):
23 |     train_transform = transforms.Compose([lambda x: torch.tensor(x),
24 |                                           transforms.RandomCrop(128, padding=16),
25 |                                           # transforms.RandomHorizontalFlip(),
26 |                                           transforms.RandomRotation(5)])
27 |     train_datasets = DVS128Gesture(os.path.join(DATA_DIR, 'DVS/DVS_Gesture'), train=True, transform=train_transform,
28 |                                    data_type='frame', split_by='number', frames_number=step)
29 |     test_datasets = DVS128Gesture(os.path.join(DATA_DIR, 'DVS/DVS_Gesture'), train=False,
30 |                                   data_type='frame', split_by='number', frames_number=step)
31 |
32 |     train_loader = torch.utils.data.DataLoader(
33 |         dataset=train_datasets,
34 |         batch_size=batch_size,
35 |         shuffle=True,
36 |         pin_memory=True,
37 |         drop_last=True,
38 |         num_workers=8
39 |     )
40 |     test_loader = torch.utils.data.DataLoader(
41 |         dataset=test_datasets,
42 |         batch_size=batch_size,
43 |         shuffle=False,
44 |         drop_last=False,
45 |         pin_memory=True,
46 |         num_workers=2
47 |     )
48 |
49 |     return train_loader, test_loader, None, None
50 |
51 |
52 | def get_dvsc10_data(batch_size, step):
53 |     train_transform = transforms.Compose([lambda x: torch.tensor(x),
54 |                                           transforms.RandomCrop(128, padding=16),
55 |                                           transforms.RandomHorizontalFlip(),
56 |                                           transforms.RandomRotation(15)])
57 |
58 |     train_datasets = CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'), transform=train_transform,
59 |                                 data_type='frame', split_by='number', frames_number=step)
60 |
61 |     test_datasets = CIFAR10DVS(os.path.join(DATA_DIR, 'DVS/DVS_Cifar10'),
62 |                                data_type='frame', split_by='number', frames_number=step)
63 |
64 |     num_train = len(train_datasets)
65 |     num_per_cls = num_train // 10
66 |     indices_train, indices_test = [], []
67 |     portion = .9
68 |     for i in range(10):
69 |         indices_train.extend(list(range(i * num_per_cls, int(i * num_per_cls + num_per_cls * portion))))
70 |         indices_test.extend(list(range(int(i * num_per_cls + num_per_cls * portion), (i + 1) * num_per_cls)))
71 |
72 |     train_loader = torch.utils.data.DataLoader(
73 |         train_datasets, batch_size=batch_size,
74 |         sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_train),
75 |         pin_memory=True, drop_last=False, num_workers=4
76 |     )
77 |
78 |     test_loader = torch.utils.data.DataLoader(
79 |         test_datasets, batch_size=batch_size,
80 |         sampler=torch.utils.data.sampler.SubsetRandomSampler(indices_test),
81 |         pin_memory=True, drop_last=False, num_workers=2
82 |     )
83 |
84 |     return train_loader, test_loader, None, None
85 |
86 |
87 | def get_nmnist_data(batch_size, step):
88 |     train_transform = transforms.Compose([lambda x: torch.tensor(x),
89 |                                           transforms.RandomCrop(34, padding=4),
90 |                                           transforms.RandomRotation(10)])
91 |     train_datasets = NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), train=True,
92 |                             data_type='frame', split_by='number', frames_number=step)
93 |     test_datasets = NMNIST(os.path.join(DATA_DIR, 'DVS/NMNIST'), train=False,
94 |                             data_type='frame', split_by='number', frames_number=step)
95 |
96 |     train_loader = torch.utils.data.DataLoader(
97 |         dataset=train_datasets,
98 |         batch_size=batch_size,
99 |         shuffle=True,
100 |         pin_memory=True,
101 |         drop_last=True,
102 |         num_workers=8
103 |     )
104 |     test_loader = torch.utils.data.DataLoader(
105 |         dataset=test_datasets,
106 |         batch_size=batch_size,
107 |         shuffle=False,
108 |         drop_last=False,
109 |         pin_memory=True,
110 |         num_workers=2
111 |     )
112 |
113 |     return train_loader, test_loader, None, None
114 |
115 |
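# Quick smoke test of a loader (assumes DVS-Gesture has already been prepared
# under DATA_DIR as described in the README); with data_type='frame' each sample
# is a frame tensor of shape [T, 2, H, W], so a batch is [N, T, 2, H, W]:
#
#   train_loader, test_loader, _, _ = get_dvsg_data(batch_size=16, step=10)
#   frames, labels = next(iter(train_loader))
#   frames.shape  # torch.Size([16, 10, 2, 128, 128])
#   labels.shape  # torch.Size([16])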
--------------------------------------------------------------------------------
/code/datasets/cifar10_dvs.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple
2 | import numpy as np
3 | from datasets.common import *
4 | from torchvision.datasets.utils import extract_archive
5 | import os
6 | import multiprocessing
7 | from concurrent.futures import ThreadPoolExecutor
8 | import time
9 |
10 | # https://github.com/jackd/events-tfds/blob/master/events_tfds/data_io/aedat.py
11 |
12 |
13 | EVT_DVS = 0  # DVS event type
14 | EVT_APS = 1  # APS event
15 |
16 |
17 | def read_bits(arr, mask=None, shift=None):
18 |     if mask is not None:
19 |         arr = arr & mask
20 |     if shift is not None:
21 |         arr = arr >> shift
22 |     return arr
23 |
24 |
25 | y_mask = 0x7FC00000
26 | y_shift = 22
27 |
28 | x_mask = 0x003FF000
29 | x_shift = 12
30 |
31 | polarity_mask = 0x800
32 | polarity_shift = 11
33 |
34 | valid_mask = 0x80000000
35 | valid_shift = 31
36 |
37 |
38 | def skip_header(fp):
39 |     p = 0
40 |     lt = fp.readline()
41 |     ltd = lt.decode().strip()
42 |     while ltd and ltd[0] == "#":
43 |         p += len(lt)
44 |         lt = fp.readline()
45 |         try:
46 |             ltd = lt.decode().strip()
47 |         except UnicodeDecodeError:
48 |             break
49 |     return p
50 |
51 |
52 | def load_raw_events(fp,
53 |                     bytes_skip=0,
54 |                     bytes_trim=0,
55 |                     filter_dvs=False,
56 |                     times_first=False):
57 |     p = skip_header(fp)
58 |     fp.seek(p + bytes_skip)
59 |     data = fp.read()
60 |     if bytes_trim > 0:
61 |         data = data[:-bytes_trim]
62 |     data = np.frombuffer(data, dtype='>u4')  # np.fromstring is deprecated for binary input
63 |     if len(data) % 2 != 0:
64 |         print(data[:20:2])
65 |         print('---')
66 |         print(data[1:21:2])
67 |         raise ValueError('odd number of data elements')
68 |     raw_addr = data[::2]
69 |     timestamp = data[1::2]
70 |     if times_first:
71 |         timestamp, raw_addr = raw_addr, timestamp
72 |     if filter_dvs:
73 |         valid = read_bits(raw_addr, valid_mask, valid_shift) == EVT_DVS
74 |         timestamp = timestamp[valid]
75 |         raw_addr = raw_addr[valid]
76 |     return timestamp, raw_addr
77 |
78 |
79 | def parse_raw_address(addr,
80 |                       x_mask=x_mask,
81 |                       x_shift=x_shift,
82 |                       y_mask=y_mask,
83 |                       y_shift=y_shift,
84 |                       polarity_mask=polarity_mask,
85 |                       polarity_shift=polarity_shift):
86 |     polarity = read_bits(addr, polarity_mask, polarity_shift).astype(bool)  # np.bool was removed in NumPy >= 1.24
87 |     x = read_bits(addr, x_mask, x_shift)
88 |     y = read_bits(addr, y_mask, y_shift)
89 |     return x, y, polarity
90 |
91 |
92 | def load_events(
93 |         fp,
94 |         filter_dvs=False,
95 |         # bytes_skip=0,
96 |         # bytes_trim=0,
97 |         # times_first=False,
98 |         **kwargs):
99 |     timestamp, addr = load_raw_events(
100 |         fp,
101 |         filter_dvs=filter_dvs,
102 |         # bytes_skip=bytes_skip,
103 |         # bytes_trim=bytes_trim,
104 |         # times_first=times_first
105 |     )
106 |     x, y, polarity = parse_raw_address(addr, **kwargs)
107 |     return timestamp, x, y, polarity
108 |
109 |
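# A tiny synthetic check of the address decoding, using the masks that
# CIFAR10DVS.load_origin_data passes below (x in bits 1-7, y in bits 8-14,
# polarity in bit 0); the address word here is fabricated for illustration:
#
#   addr = np.array([(5 << 8) | (3 << 1) | 1], dtype=np.uint32)  # y=5, x=3, p=1
#   read_bits(addr, 0xFE, 1)      # -> array([3])  x
#   read_bits(addr, 0x7F00, 8)    # -> array([5])  y
#   read_bits(addr, 1, None)      # -> array([1])  polarity (bit 0, no shift)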
110 | class CIFAR10DVS(NeuromorphicDatasetFolder):
111 |     def __init__(
112 |             self,
113 |             root: str,
114 |             data_type: str = 'event',
115 |             frames_number: int = None,
116 |             split_by: str = None,
117 |             duration: int = None,
118 |             transform: Optional[Callable] = None,
119 |             target_transform: Optional[Callable] = None,
120 |     ) -> None:
121 |         '''
122 |         :param root: root path of the dataset
123 |         :type root: str
124 |         :param data_type: `event` or `frame`
125 |         :type data_type: str
126 |         :param frames_number: the integrated frame number
127 |         :type frames_number: int
128 |         :param split_by: `time` or `number`
129 |         :type split_by: str
130 |         :param duration: the time duration of each frame
131 |         :type duration: int
132 |         :param transform: a function/transform that takes in
133 |         a sample and returns a transformed version.
134 |         E.g., ``transforms.RandomCrop`` for images.
135 |         :type transform: callable
136 |         :param target_transform: a function/transform that takes
137 |         in the target and transforms it.
138 |         :type target_transform: callable
139 |
140 |         If ``data_type == 'event'``
141 |         the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``.
142 |
143 |         If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
144 |         events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events.
145 |         See :class:`cal_fixed_frames_number_segment_index` for
146 |         more details.
147 |
148 |         If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
149 |         events will be integrated to frames with fixed time duration.
150 |
151 |         '''
152 |         super().__init__(root, None, data_type, frames_number, split_by, duration, transform,
153 |                          target_transform)
154 |
155 |     @staticmethod
156 |     def resource_url_md5() -> list:
157 |         '''
158 |         :return: A list ``url`` where ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5
159 |         :rtype: list
160 |         '''
161 |         return [
162 |             ('airplane.zip', 'https://ndownloader.figshare.com/files/7712788', '0afd5c4bf9ae06af762a77b180354fdd'),
163 |             ('automobile.zip', 'https://ndownloader.figshare.com/files/7712791', '8438dfeba3bc970c94962d995b1b9bdd'),
164 |             ('bird.zip', 'https://ndownloader.figshare.com/files/7712794', 'a9c207c91c55b9dc2002dc21c684d785'),
165 |             ('cat.zip', 'https://ndownloader.figshare.com/files/7712812', '52c63c677c2b15fa5146a8daf4d56687'),
166 |             ('deer.zip', 'https://ndownloader.figshare.com/files/7712815', 'b6bf21f6c04d21ba4e23fc3e36c8a4a3'),
167 |             ('dog.zip', 'https://ndownloader.figshare.com/files/7712818', 'f379ebdf6703d16e0a690782e62639c3'),
168 |             ('frog.zip', 'https://ndownloader.figshare.com/files/7712842', 'cad6ed91214b1c7388a5f6ee56d08803'),
169 |             ('horse.zip', 'https://ndownloader.figshare.com/files/7712851', 'e7cbbf77bec584ffbf913f00e682782a'),
170 |             ('ship.zip', 'https://ndownloader.figshare.com/files/7712836', '41c7bd7d6b251be82557c6cce9a7d5c9'),
171 |             ('truck.zip', 'https://ndownloader.figshare.com/files/7712839', '89f3922fd147d9aeff89e76a2b0b70a7')
172 |         ]
173 |
174 |     @staticmethod
175 |     def downloadable() -> bool:
176 |         '''
177 |         :return: Whether the dataset can be directly downloaded by Python code. If not, the user has to download it manually
178 |         :rtype: bool
179 |         '''
180 |         return True
181 |
182 |     @staticmethod
183 |     def extract_downloaded_files(download_root: str, extract_root: str):
184 |         '''
185 |         :param download_root: Root directory path which saves downloaded dataset files
186 |         :type download_root: str
187 |         :param extract_root: Root directory path which saves extracted files from downloaded files
188 |         :type extract_root: str
189 |         :return: None
190 |
191 |         This function defines how to extract downloaded files.
192 |         '''
193 |         with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 10)) as tpe:
194 |             for zip_file in os.listdir(download_root):
195 |                 zip_file = os.path.join(download_root, zip_file)
196 |                 print(f'Extract [{zip_file}] to [{extract_root}].')
197 |                 tpe.submit(extract_archive, zip_file, extract_root)
198 |
199 |     @staticmethod
200 |     def load_origin_data(file_name: str) -> Dict:
201 |         '''
202 |         :param file_name: path of the events file
203 |         :type file_name: str
204 |         :return: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
205 |         :rtype: Dict
206 |
207 |         This function defines how to read the origin binary data.
208 |         '''
209 |         with open(file_name, 'rb') as fp:
210 |             t, x, y, p = load_events(fp,
211 |                                      x_mask=0xfE,
212 |                                      x_shift=1,
213 |                                      y_mask=0x7f00,
214 |                                      y_shift=8,
215 |                                      polarity_mask=1,
216 |                                      polarity_shift=None)
217 |             # return {'t': t, 'x': 127 - x, 'y': y, 'p': 1 - p.astype(int)}  # this will get the same data as http://www2.imse-cnm.csic.es/caviar/MNIST_DVS/dat2mat.m
218 |             # see https://github.com/jackd/events-tfds/pull/1 for more details about this problem
219 |             return {'t': t, 'x': 127 - y, 'y': 127 - x, 'p': 1 - p.astype(int)}
220 |
221 |     @staticmethod
222 |     def get_H_W() -> Tuple:
223 |         '''
224 |         :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
225 |         For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
226 |         :rtype: tuple
227 |         '''
228 |         return 128, 128
229 |
230 |     @staticmethod
231 |     def read_aedat_save_to_np(bin_file: str, np_file: str):
232 |         events = CIFAR10DVS.load_origin_data(bin_file)
233 |         np.savez(np_file,
234 |                  t=events['t'],
235 |                  x=events['x'],
236 |                  y=events['y'],
237 |                  p=events['p']
238 |                  )
239 |         print(f'Save [{bin_file}] to [{np_file}].')
240 |
241 |     @staticmethod
242 |     def create_events_np_files(extract_root: str, events_np_root: str):
243 |         '''
244 |         :param extract_root: Root directory path which saves extracted files from downloaded files
245 |         :type extract_root: str
246 |         :param events_np_root: Root directory path which saves events files in the ``npz`` format
247 |         :type events_np_root: str
248 |         :return: None
249 |
250 |         This function defines how to convert the origin binary data in ``extract_root`` to ``npz`` format and save converted files in ``events_np_root``.
251 |         '''
252 |         t_ckp = time.time()
253 |         with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 64)) as tpe:
254 |             for class_name in os.listdir(extract_root):
255 |                 aedat_dir = os.path.join(extract_root, class_name)
256 |                 np_dir = os.path.join(events_np_root, class_name)
257 |                 os.mkdir(np_dir)
258 |                 print(f'Mkdir [{np_dir}].')
259 |                 for bin_file in os.listdir(aedat_dir):
260 |                     source_file = os.path.join(aedat_dir, bin_file)
261 |                     target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
262 |                     print(f'Start to convert [{source_file}] to [{target_file}].')
263 |                     tpe.submit(CIFAR10DVS.read_aedat_save_to_np, source_file,
264 |                                target_file)
265 |         print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
--------------------------------------------------------------------------------
/code/datasets/common.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import DatasetFolder
2 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple
3 | from abc import abstractmethod
4 | import scipy.io
5 | import struct
6 | import numpy as np
7 | from torchvision.datasets import utils
8 | import torch.utils.data
9 | import os
10 | from concurrent.futures import ThreadPoolExecutor
11 | import time
12 | import multiprocessing
13 | from torchvision import transforms
14 | import torch
15 | from matplotlib import pyplot as plt
16 | import math
17 |
18 |
19 | def play_frame(x: torch.Tensor or np.ndarray, save_gif_to: str = None) -> None:
20 |     '''
21 |     :param x: frames with ``shape=[T, 2, H, W]``
22 |     :type x: torch.Tensor or np.ndarray
23 |     :param save_gif_to: If ``None``, this function will play the frames.
Otherwise, this function will not play the frames
24 |     but save them to a gif file at the path ``save_gif_to``
25 |     :type save_gif_to: str
26 |     :return: None
27 |     '''
28 |     if isinstance(x, np.ndarray):
29 |         x = torch.from_numpy(x)
30 |     to_img = transforms.ToPILImage()
31 |     img_tensor = torch.zeros([x.shape[0], 3, x.shape[2], x.shape[3]])
32 |     img_tensor[:, 1] = x[:, 0]
33 |     img_tensor[:, 2] = x[:, 1]
34 |     if save_gif_to is None:
35 |         while True:
36 |             for t in range(img_tensor.shape[0]):
37 |                 plt.imshow(to_img(img_tensor[t]))
38 |                 plt.pause(0.01)
39 |     else:
40 |         img_list = []
41 |         for t in range(img_tensor.shape[0]):
42 |             img_list.append(to_img(img_tensor[t]))
43 |         img_list[0].save(save_gif_to, save_all=True, append_images=img_list[1:], loop=0)
44 |         print(f'Save frames to [{save_gif_to}].')
45 |
46 |
47 | def load_matlab_mat(file_name: str) -> Dict:
48 |     '''
49 |     :param file_name: path of the matlab's mat file
50 |     :type file_name: str
51 |     :return: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
52 |     :rtype: Dict
53 |     '''
54 |     events = scipy.io.loadmat(file_name)
55 |     return {
56 |         't': events['ts'].squeeze(),
57 |         'x': events['x'].squeeze(),
58 |         'y': events['y'].squeeze(),
59 |         'p': events['pol'].squeeze()
60 |     }
61 |
62 |
63 | def load_aedat_v3(file_name: str) -> Dict:
64 |     '''
65 |     :param file_name: path of the aedat v3 file
66 |     :type file_name: str
67 |     :return: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
68 |     :rtype: Dict
69 |
70 |     This function is written by referring to https://gitlab.com/inivation/dv/dv-python . It can be used for DVS128 Gesture.
71 |     '''
72 |     with open(file_name, 'rb') as bin_f:
73 |         # skip ascii header
74 |         line = bin_f.readline()
75 |         while line.startswith(b'#'):
76 |             if line == b'#!END-HEADER\r\n':
77 |                 break
78 |             else:
79 |                 line = bin_f.readline()
80 |
81 |         txyp = {
82 |             't': [],
83 |             'x': [],
84 |             'y': [],
85 |             'p': []
86 |         }
87 |         while True:
88 |             header = bin_f.read(28)
89 |             if not header or len(header) == 0:
90 |                 break
91 |
92 |             # read header
93 |             e_type = struct.unpack('H', header[0:2])[0]
94 |             e_source = struct.unpack('H', header[2:4])[0]
95 |             e_size = struct.unpack('I', header[4:8])[0]
96 |             e_offset = struct.unpack('I', header[8:12])[0]
97 |             e_tsoverflow = struct.unpack('I', header[12:16])[0]
98 |             e_capacity = struct.unpack('I', header[16:20])[0]
99 |             e_number = struct.unpack('I', header[20:24])[0]
100 |             e_valid = struct.unpack('I', header[24:28])[0]
101 |
102 |             data_length = e_capacity * e_size
103 |             data = bin_f.read(data_length)
104 |             counter = 0
105 |
106 |             if e_type == 1:
107 |                 while data[counter:counter + e_size]:
108 |                     aer_data = struct.unpack('I', data[counter:counter + 4])[0]
109 |                     timestamp = struct.unpack('I', data[counter + 4:counter + 8])[0] | e_tsoverflow << 31
110 |                     x = (aer_data >> 17) & 0x00007FFF
111 |                     y = (aer_data >> 2) & 0x00007FFF
112 |                     pol = (aer_data >> 1) & 0x00000001
113 |                     counter = counter + e_size
114 |                     txyp['x'].append(x)
115 |                     txyp['y'].append(y)
116 |                     txyp['t'].append(timestamp)
117 |                     txyp['p'].append(pol)
118 |             else:
119 |                 # non-polarity event packet, not implemented
120 |                 pass
121 |         txyp['x'] = np.asarray(txyp['x'])
122 |         txyp['y'] = np.asarray(txyp['y'])
123 |         txyp['t'] = np.asarray(txyp['t'])
124 |         txyp['p'] = np.asarray(txyp['p'])
125 |     return txyp
126 |
127 |
128 | def load_ATIS_bin(file_name: str) -> Dict:
129 |     '''
130 |     :param file_name: path of the ATIS binary file
131 |     :type file_name: str
132 |     :return: a dict whose keys are ['t', 'x', 'y',
'p'] and values are ``numpy.ndarray``
133 |     :rtype: Dict
134 |
135 |     This function is written by referring to https://github.com/jackd/events-tfds .
136 |
137 |     Each ATIS binary example is a separate binary file consisting of a list of events. Each event occupies 40 bits as described below:
138 |     bit 39 - 32: Xaddress (in pixels)
139 |     bit 31 - 24: Yaddress (in pixels)
140 |     bit 23: Polarity (0 for OFF, 1 for ON)
141 |     bit 22 - 0: Timestamp (in microseconds)
142 |     '''
143 |     with open(file_name, 'rb') as bin_f:
144 |         # `& 128` extracts the most significant bit of an 8-bit value
145 |         # `& 127` extracts the remaining 7 bits, i.e., everything except the most significant bit
146 |         raw_data = np.uint32(np.fromfile(bin_f, dtype=np.uint8))
147 |         x = raw_data[0::5]
148 |         y = raw_data[1::5]
149 |         rd_2__5 = raw_data[2::5]
150 |         p = (rd_2__5 & 128) >> 7
151 |         t = ((rd_2__5 & 127) << 16) | (raw_data[3::5] << 8) | (raw_data[4::5])
152 |     return {'t': t, 'x': x, 'y': y, 'p': p}
153 |
154 |
155 | def load_npz_frames(file_name: str) -> np.ndarray:
156 |     '''
157 |     :param file_name: path of the npz file that saves the frames
158 |     :type file_name: str
159 |     :return: frames
160 |     :rtype: np.ndarray
161 |     '''
162 |     return np.load(file_name)['frames']
163 |
164 |
165 | def integrate_events_segment_to_frame(events: Dict, H: int, W: int, j_l: int = 0, j_r: int = -1) -> np.ndarray:
166 |     '''
167 |     :param events: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
168 |     :type events: Dict
169 |     :param H: height of the frame
170 |     :type H: int
171 |     :param W: width of the frame
172 |     :type W: int
173 |     :param j_l: the start index of the integral interval, which is included
174 |     :type j_l: int
175 |     :param j_r: the right index of the integral interval, which is not included
176 |     :type j_r: int
177 |     :return: frames
178 |     :rtype: np.ndarray
179 |
180 |     Denote a two-channel frame as :math:`F` and a pixel at :math:`(p, x, y)` as :math:`F(p, x, y)`; the pixel value is integrated from the events data whose indices are in :math:`[j_{l}, j_{r})`:
181 |
182 |     .. math::
183 |
184 |         F(p, x, y) = \\sum_{i = j_{l}}^{j_{r} - 1} \\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})
185 |
186 |     where :math:`\\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})` is an indicator function that equals 1 only when :math:`(p, x, y) = (p_{i}, x_{i}, y_{i})`.
187 |     '''
188 |     # Accumulated spikes must be counted with ``bincount`` and cannot simply be added; see the example code below, and
189 |     # https://stackoverflow.com/questions/15973827/handling-of-duplicate-indices-in-numpy-assignments
190 |     # We must use ``bincount`` rather than simply ``+``.
See the following reference: 191 | # https://stackoverflow.com/questions/15973827/handling-of-duplicate-indices-in-numpy-assignments 192 | 193 | # Here is an example: 194 | 195 | # height = 3 196 | # width = 3 197 | # frames = np.zeros(shape=[2, height, width]) 198 | # events = { 199 | # 'x': np.asarray([1, 2, 1, 1]), 200 | # 'y': np.asarray([1, 1, 1, 2]), 201 | # 'p': np.asarray([0, 1, 0, 1]) 202 | # } 203 | # 204 | # frames[0, events['y'], events['x']] += (1 - events['p']) 205 | # frames[1, events['y'], events['x']] += events['p'] 206 | # print('wrong accumulation\n', frames) 207 | # 208 | # frames = np.zeros(shape=[2, height, width]) 209 | # for i in range(events['p'].__len__()): 210 | # frames[events['p'][i], events['y'][i], events['x'][i]] += 1 211 | # print('correct accumulation\n', frames) 212 | # 213 | # frames = np.zeros(shape=[2, height, width]) 214 | # frames = frames.reshape(2, -1) 215 | # 216 | # mask = [events['p'] == 0] 217 | # mask.append(np.logical_not(mask[0])) 218 | # for i in range(2): 219 | # position = events['y'][mask[i]] * width + events['x'][mask[i]] 220 | # events_number_per_pos = np.bincount(position) 221 | # idx = np.arange(events_number_per_pos.size) 222 | # frames[i][idx] += events_number_per_pos 223 | # frames = frames.reshape(2, height, width) 224 | # print('correct accumulation by bincount\n', frames) 225 | 226 | frame = np.zeros(shape=[2, H * W]) 227 | x = events['x'][j_l: j_r].astype(int) # avoid overflow 228 | y = events['y'][j_l: j_r].astype(int) 229 | p = events['p'][j_l: j_r] 230 | mask = [] 231 | mask.append(p == 0) 232 | mask.append(np.logical_not(mask[0])) 233 | for c in range(2): 234 | position = y[mask[c]] * W + x[mask[c]] 235 | events_number_per_pos = np.bincount(position) 236 | frame[c][np.arange(events_number_per_pos.size)] += events_number_per_pos 237 | return frame.reshape((2, H, W)) 238 | 239 | 240 | def cal_fixed_frames_number_segment_index(events_t: np.ndarray, split_by: str, frames_num: int) -> tuple: 241 | ''' 242 | :param events_t: events' t 243 | :type events_t: numpy.ndarray 244 | :param split_by: 'time' or 'number' 245 | :type split_by: str 246 | :param frames_num: the number of frames 247 | :type frames_num: int 248 | :return: a tuple ``(j_l, j_r)`` 249 | :rtype: tuple 250 | 251 | Denote ``frames_num`` as :math:`M`, if ``split_by`` is ``'time'``, then 252 | 253 | .. math:: 254 | 255 | \\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\ 256 | j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\ 257 | j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j < M - 1 \\cr N, & j = M - 1 \\end{cases} 258 | 259 | If ``split_by`` is ``'number'``, then 260 | 261 | .. 
math::
262 |         j_{l} & = [\\frac{N}{M}] \\cdot j \\\\
263 |         j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j < M - 1 \\cr N, & j = M - 1 \\end{cases}
264 |     '''
265 |     j_l = np.zeros(shape=[frames_num], dtype=int)
266 |     j_r = np.zeros(shape=[frames_num], dtype=int)
267 |     N = events_t.size
268 |
269 |     if split_by == 'number':
270 |         di = N // frames_num
271 |         for i in range(frames_num):
272 |             j_l[i] = i * di
273 |             j_r[i] = j_l[i] + di
274 |         j_r[-1] = N
275 |
276 |     elif split_by == 'time':
277 |         dt = (events_t[-1] - events_t[0]) // frames_num
278 |         idx = np.arange(N)
279 |         for i in range(frames_num):
280 |             t_l = dt * i + events_t[0]
281 |             t_r = t_l + dt
282 |             mask = np.logical_and(events_t >= t_l, events_t < t_r)
283 |             idx_masked = idx[mask]
284 |             j_l[i] = idx_masked[0]
285 |             j_r[i] = idx_masked[-1] + 1
286 |
287 |         j_r[-1] = N
288 |     else:
289 |         raise NotImplementedError
290 |
291 |     return j_l, j_r
292 |
293 |
294 | def integrate_events_by_fixed_frames_number(events: Dict, split_by: str, frames_num: int, H: int, W: int) -> np.ndarray:
295 |     '''
296 |     :param events: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
297 |     :type events: Dict
298 |     :param split_by: 'time' or 'number'
299 |     :type split_by: str
300 |     :param frames_num: the number of frames
301 |     :type frames_num: int
302 |     :param H: the height of frame
303 |     :type H: int
304 |     :param W: the width of frame
305 |     :type W: int
306 |     :return: frames
307 |     :rtype: np.ndarray
308 |
309 |     Integrate events to frames by fixed frames number. See ``cal_fixed_frames_number_segment_index`` and ``integrate_events_segment_to_frame`` for more details.
310 |     '''
311 |     j_l, j_r = cal_fixed_frames_number_segment_index(events['t'], split_by, frames_num)
312 |     frames = np.zeros([frames_num, 2, H, W])
313 |     for i in range(frames_num):
314 |         frames[i] = integrate_events_segment_to_frame(events, H, W, j_l[i], j_r[i])
315 |     return frames
316 |
317 |
318 | def integrate_events_file_to_frames_file_by_fixed_frames_number(events_np_file: str, output_dir: str, split_by: str, frames_num: int, H: int, W: int, print_save: bool = False) -> None:
319 |     '''
320 |     :param events_np_file: path of the events np file
321 |     :type events_np_file: str
322 |     :param output_dir: output directory for saving the frames
323 |     :type output_dir: str
324 |     :param split_by: 'time' or 'number'
325 |     :type split_by: str
326 |     :param frames_num: the number of frames
327 |     :type frames_num: int
328 |     :param H: the height of frame
329 |     :type H: int
330 |     :param W: the width of frame
331 |     :type W: int
332 |     :param print_save: If ``True``, this function will print saved files' paths.
333 |     :type print_save: bool
334 |     :return: None
335 |
336 |     Integrate an events file to frames by fixed frames number and save it. See ``cal_fixed_frames_number_segment_index`` and ``integrate_events_segment_to_frame`` for more details.
337 |     '''
338 |     fname = os.path.join(output_dir, os.path.basename(events_np_file))
339 |     np.savez(fname, frames=integrate_events_by_fixed_frames_number(np.load(events_np_file), split_by, frames_num, H, W))
340 |     if print_save:
341 |         print(f'Frames [{fname}] saved.')
342 |
343 |
344 | def integrate_events_by_fixed_duration(events: Dict, duration: int, H: int, W: int) -> np.ndarray:
345 |     '''
346 |     :param events: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
347 |     :type events: Dict
348 |     :param duration: the time duration of each frame
349 |     :type duration: int
350 |     :param H: the height of frame
351 |     :type H: int
352 |     :param W: the width of frame
353 |     :type W: int
354 |     :return: frames
355 |     :rtype: np.ndarray
356 |
357 |     Integrate events to frames by fixed time duration of each frame.
358 |     '''
359 |     t = events['t']
360 |     N = t.size
361 |
362 |     frames = []
363 |     left = 0
364 |     right = 0
365 |     while True:
366 |         t_l = t[left]
367 |         while True:
368 |             if right == N or t[right] - t_l > duration:
369 |                 break
370 |             else:
371 |                 right += 1
372 |         # integrate from index [left, right)
373 |         frames.append(np.expand_dims(integrate_events_segment_to_frame(events, H, W, left, right), 0))
374 |
375 |         left = right
376 |
377 |         if right == N:
378 |             return np.concatenate(frames)
379 |
380 |
381 | def integrate_events_file_to_frames_file_by_fixed_duration(events_np_file: str, output_dir: str, duration: int, H: int, W: int, print_save: bool = False) -> int:
382 |     '''
383 |     :param events_np_file: path of the events np file
384 |     :type events_np_file: str
385 |     :param output_dir: output directory for saving the frames
386 |     :type output_dir: str
387 |     :param duration: the time duration of each frame
388 |     :type duration: int
389 |     :param H: the height of frame
390 |     :type H: int
391 |     :param W: the width of frame
392 |     :type W: int
393 |     :param print_save: If ``True``, this function will print saved files' paths.
394 |     :type print_save: bool
395 |     :return: the number of frames in the saved file
396 |
397 |     Integrate events to frames by fixed time duration of each frame.
398 |     '''
399 |     frames = integrate_events_by_fixed_duration(np.load(events_np_file), duration, H, W)
400 |     fname, _ = os.path.splitext(os.path.basename(events_np_file))
401 |     fname = os.path.join(output_dir, f'{fname}_{frames.shape[0]}.npz')
402 |     np.savez(fname, frames=frames)
403 |     if print_save:
404 |         print(f'Frames [{fname}] saved.')
405 |     return frames.shape[0]
406 |
407 |
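# Worked example for the fixed-frames-number path: with split_by='number', four
# synthetic events are divided into two frames of two events each, and every
# event increments its (polarity, y, x) cell (toy values for illustration):
#
#   events = {'t': np.arange(4),
#             'x': np.array([0, 1, 0, 1]),
#             'y': np.array([0, 0, 1, 1]),
#             'p': np.array([0, 1, 0, 1])}
#   frames = integrate_events_by_fixed_frames_number(events, 'number', 2, H=2, W=2)
#   frames.shape        # (2, 2, 2, 2) -- [T, polarity, H, W]
#   frames[0, 0, 0, 0]  # 1.0: the first OFF event fell at (y=0, x=0)
#   frames[1, 1, 1, 1]  # 1.0: the last ON event fell at (y=1, x=1)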
408 | def create_same_directory_structure(source_dir: str, target_dir: str) -> None:
409 |     '''
410 |     :param source_dir: Path of the directory to be copied from
411 |     :type source_dir: str
412 |     :param target_dir: Path of the directory to be copied to
413 |     :type target_dir: str
414 |     :return: None
415 |
416 |     Create the same directory structure in ``target_dir`` as that of ``source_dir``.
417 |     '''
418 |     for sub_dir_name in os.listdir(source_dir):
419 |         source_sub_dir = os.path.join(source_dir, sub_dir_name)
420 |         if os.path.isdir(source_sub_dir):
421 |             target_sub_dir = os.path.join(target_dir, sub_dir_name)
422 |             os.mkdir(target_sub_dir)
423 |             print(f'Mkdir [{target_sub_dir}].')
424 |             create_same_directory_structure(source_sub_dir, target_sub_dir)
425 |
426 |
427 | def split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data.Dataset, num_classes: int, random_split: bool = False):
428 |     '''
429 |     :param train_ratio: the ratio of the origin dataset to split out as the train set
430 |     :type train_ratio: float
431 |     :param origin_dataset: the origin dataset
432 |     :type origin_dataset: torch.utils.data.Dataset
433 |     :param num_classes: total classes number, e.g., ``10`` for the MNIST dataset
434 |     :type num_classes: int
435 |     :param random_split: If ``False``, the front ratio of samples in each class will
436 |     be included in the train set, while the rest will be included in the test set.
437 |     If ``True``, this function will split samples in each class randomly. The randomness is controlled by
438 |     ``numpy.random.seed``
439 |     :type random_split: bool
440 |     :return: a tuple ``(train_set, test_set)``
441 |     :rtype: tuple
442 |     '''
443 |     label_idx = []
444 |     for i in range(num_classes):
445 |         label_idx.append([])
446 |
447 |     for i, (x, y) in enumerate(origin_dataset):
448 |         if isinstance(y, np.ndarray) or isinstance(y, torch.Tensor):
449 |             y = y.item()
450 |         label_idx[y].append(i)
451 |     train_idx = []
452 |     test_idx = []
453 |     if random_split:
454 |         for i in range(num_classes):
455 |             np.random.shuffle(label_idx[i])
456 |
457 |     for i in range(num_classes):
458 |         pos = math.ceil(label_idx[i].__len__() * train_ratio)
459 |         train_idx.extend(label_idx[i][0: pos])
460 |         test_idx.extend(label_idx[i][pos: label_idx[i].__len__()])
461 |
462 |     return torch.utils.data.Subset(origin_dataset, train_idx), torch.utils.data.Subset(origin_dataset, test_idx)
463 |
464 |
465 | def pad_sequence_collate(batch: list):
466 |     '''
467 |     :param batch: a list of samples that contains ``(x, y)``, where ``x.shape=[T, *]`` and ``y`` is the label
468 |     :type batch: list
469 |     :return: batched samples, where ``x`` is padded with the same length
470 |     :rtype: tuple
471 |
472 |     This function can be used as the ``collate_fn`` for ``DataLoader`` to process the dataset with variable length, e.g., a ``NeuromorphicDatasetFolder`` with fixed duration to integrate events to frames.
473 |
474 |     Here is an example:
475 |
476 |     .. code-block:: python
477 |
478 |         class RandomLengthDataset(torch.utils.data.Dataset):
479 |             def __init__(self, n=1000):
480 |                 super().__init__()
481 |                 self.n = n
482 |
483 |             def __getitem__(self, i):
484 |                 return torch.rand([random.randint(1, 10), 28, 28]), random.randint(0, 10)
485 |
486 |             def __len__(self):
487 |                 return self.n
488 |
489 |         loader = torch.utils.data.DataLoader(RandomLengthDataset(n=32), batch_size=16, collate_fn=pad_sequence_collate)
490 |
491 |         for x, y, z in loader:
492 |             print(x.shape, y.shape, z)
493 |
494 |     And the outputs are:
495 |
496 |     ..
code-block:: bash 497 | 498 | torch.Size([10, 16, 28, 28]) torch.Size([16]) tensor([ 1, 9, 3, 4, 1, 2, 9, 7, 2, 1, 5, 7, 4, 10, 9, 5]) 499 | torch.Size([10, 16, 28, 28]) torch.Size([16]) tensor([ 1, 8, 7, 10, 3, 10, 6, 7, 5, 9, 10, 5, 9, 6, 7, 6]) 500 | 501 | ''' 502 | x_list = [] 503 | x_len_list = [] 504 | y_list = [] 505 | for x, y in batch: 506 | x_list.append(torch.as_tensor(x)) 507 | x_len_list.append(x.shape[0]) 508 | y_list.append(y) 509 | 510 | return torch.nn.utils.rnn.pad_sequence(x_list, batch_first=True), torch.as_tensor(y_list), torch.as_tensor(x_len_list) 511 | 512 | 513 | def padded_sequence_mask(sequence_len: torch.Tensor, T=None): 514 | ''' 515 | :param sequence_len: a tensor ``shape = [N]`` that contains sequences lengths of each batch element 516 | :type sequence_len: torch.Tensor 517 | :param T: The maximum length of sequences. If ``None``, the maximum element in ``sequence_len`` will be seen as ``T`` 518 | :type T: int 519 | :return: a bool mask with shape = [T, N], where the padded position is ``False`` 520 | :rtype: torch.Tensor 521 | 522 | Here is an example: 523 | 524 | .. code-block:: python 525 | 526 | x1 = torch.rand([2, 6]) 527 | x2 = torch.rand([3, 6]) 528 | x3 = torch.rand([4, 6]) 529 | x = torch.nn.utils.rnn.pad_sequence([x1, x2, x3]) # [T, N, *] 530 | print('x.shape=', x.shape) 531 | x_len = torch.as_tensor([x1.shape[0], x2.shape[0], x3.shape[0]]) 532 | mask = padded_sequence_mask(x_len) 533 | print('mask.shape=', mask.shape) 534 | print('mask=\n', mask) 535 | 536 | And the outputs are: 537 | 538 | .. code-block:: bash 539 | 540 | x.shape= torch.Size([4, 3, 6]) 541 | mask.shape= torch.Size([4, 3]) 542 | mask= 543 | tensor([[ True, True, True], 544 | [ True, True, True], 545 | [False, True, True], 546 | [False, False, True]]) 547 | 548 | ''' 549 | if T is None: 550 | T = sequence_len.max().item() 551 | N = sequence_len.numel() 552 | t_seq = torch.arange(0, T).unsqueeze(1).repeat(1, N).to(sequence_len) # [T, N] 553 | return t_seq < sequence_len.unsqueeze(0).repeat(T, 1) 554 | 555 | 556 | class NeuromorphicDatasetFolder(DatasetFolder): 557 | def __init__( 558 | self, 559 | root: str, 560 | train: bool = None, 561 | data_type: str = 'event', 562 | frames_number: int = None, 563 | split_by: str = None, 564 | duration: int = None, 565 | transform: Optional[Callable] = None, 566 | target_transform: Optional[Callable] = None, 567 | ) -> None: 568 | ''' 569 | :param root: root path of the dataset 570 | :type root: str 571 | :param train: whether use the train set. Set ``True`` or ``False`` for those datasets provide train/test 572 | division, e.g., DVS128 Gesture dataset. If the dataset does not provide train/test division, e.g., CIFAR10-DVS, 573 | please set ``None`` and use :class:`~split_to_train_test_set` function to get train/test set 574 | :type train: bool 575 | :param data_type: `event` or `frame` 576 | :type data_type: str 577 | :param frames_number: the integrated frame number 578 | :type frames_number: int 579 | :param split_by: `time` or `number` 580 | :type split_by: str 581 | :param duration: the time duration of each frame 582 | :type duration: int 583 | :param transform: a function/transform that takes in 584 | a sample and returns a transformed version. 585 | E.g, ``transforms.RandomCrop`` for images. 586 | :type transform: callable 587 | :param target_transform: a function/transform that takes 588 | in the target and transforms it. 589 | :type target_transform: callable 590 | 591 | The code class for neuromorphic dataset. 
Users can define a new dataset by inheriting this class and implementing 592 | all abstract methods. Users can refer to :class:`spikingjelly.datasets.dvs128_gesture.DVS128Gesture`. 593 | 594 | If ``data_type == 'event'`` 595 | the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``. 596 | 597 | If ``data_type == 'frame'`` and ``frames_number`` is not ``None`` 598 | events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events. 599 | See :class:`cal_fixed_frames_number_segment_index` for 600 | more details. 601 | 602 | If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None`` 603 | events will be integrated to frames with fixed time duration. 604 | 605 | ''' 606 | 607 | events_np_root = os.path.join(root, 'events_np') 608 | 609 | if not os.path.exists(events_np_root): 610 | 611 | download_root = os.path.join(root, 'download') 612 | 613 | if os.path.exists(download_root): 614 | print(f'The [{download_root}] directory for saving downloaded files already exists, check files...') 615 | # check files 616 | resource_list = self.resource_url_md5() 617 | for i in range(resource_list.__len__()): 618 | file_name, url, md5 = resource_list[i] 619 | fpath = os.path.join(download_root, file_name) 620 | if not utils.check_integrity(fpath=fpath, md5=md5): 621 | print(f'The file [{fpath}] does not exist or is corrupted.') 622 | 623 | if os.path.exists(fpath): 624 | # If file is corrupted, we will remove it. 625 | os.remove(fpath) 626 | print(f'Remove [{fpath}]') 627 | 628 | if self.downloadable(): 629 | # If file does not exist, we will download it. 630 | print(f'Download [{file_name}] from [{url}] to [{download_root}]') 631 | utils.download_url(url=url, root=download_root, filename=file_name, md5=md5) 632 | else: 633 | raise NotImplementedError( 634 | f'This dataset can not be downloaded by SpikingJelly, please download [{file_name}] from [{url}] manually and put files at {download_root}.') 635 | 636 | else: 637 | os.mkdir(download_root) 638 | print(f'Mkdir [{download_root}] to save downloaded files.') 639 | resource_list = self.resource_url_md5() 640 | if self.downloadable(): 641 | # download and extract file 642 | for i in range(resource_list.__len__()): 643 | file_name, url, md5 = resource_list[i] 644 | print(f'Download [{file_name}] from [{url}] to [{download_root}]') 645 | utils.download_url(url=url, root=download_root, filename=file_name, md5=md5) 646 | else: 647 | raise NotImplementedError(f'This dataset can not be downloaded by SpikingJelly, ' 648 | f'please download files manually and put files at [{download_root}]. ' 649 | f'The resources file_name, url, and md5 are: \n{resource_list}') 650 | 651 | # We have downloaded files and checked files. 
Now, let us extract the files 652 | extract_root = os.path.join(root, 'extract') 653 | if os.path.exists(extract_root): 654 | print(f'The directory [{extract_root}] for saving extracted files already exists.\n' 655 | f'SpikingJelly will not check the data integrity of extracted files.\n' 656 | f'If extracted files are not integrated, please delete [{extract_root}] manually, ' 657 | f'then SpikingJelly will re-extract files from [{download_root}].') 658 | # shutil.rmtree(extract_root) 659 | # print(f'Delete [{extract_root}].') 660 | else: 661 | os.mkdir(extract_root) 662 | print(f'Mkdir [{extract_root}].') 663 | self.extract_downloaded_files(download_root, extract_root) 664 | 665 | # Now let us convert the origin binary files to npz files 666 | os.mkdir(events_np_root) 667 | print(f'Mkdir [{events_np_root}].') 668 | print(f'Start to convert the origin data from [{extract_root}] to [{events_np_root}] in np.ndarray format.') 669 | self.create_events_np_files(extract_root, events_np_root) 670 | 671 | H, W = self.get_H_W() 672 | 673 | if data_type == 'event': 674 | _root = events_np_root 675 | _loader = np.load 676 | _transform = transform 677 | _target_transform = target_transform 678 | 679 | elif data_type == 'frame': 680 | if frames_number is not None: 681 | assert frames_number > 0 and isinstance(frames_number, int) 682 | assert split_by == 'time' or split_by == 'number' 683 | frames_np_root = os.path.join(root, f'frames_number_{frames_number}_split_by_{split_by}') 684 | if os.path.exists(frames_np_root): 685 | print(f'The directory [{frames_np_root}] already exists.') 686 | else: 687 | os.mkdir(frames_np_root) 688 | print(f'Mkdir [{frames_np_root}].') 689 | 690 | # create the same directory structure 691 | create_same_directory_structure(events_np_root, frames_np_root) 692 | 693 | # use multi-thread to accelerate 694 | t_ckp = time.time() 695 | with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 64)) as tpe: 696 | print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].') 697 | for e_root, e_dirs, e_files in os.walk(events_np_root): 698 | if e_files.__len__() > 0: 699 | output_dir = os.path.join(frames_np_root, os.path.relpath(e_root, events_np_root)) 700 | for e_file in e_files: 701 | events_np_file = os.path.join(e_root, e_file) 702 | print(f'Start to integrate [{events_np_file}] to frames and save to [{output_dir}].') 703 | tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number, events_np_file, output_dir, split_by, frames_number, H, W, True) 704 | 705 | print(f'Used time = [{round(time.time() - t_ckp, 2)}s].') 706 | 707 | _root = frames_np_root 708 | _loader = load_npz_frames 709 | _transform = transform 710 | _target_transform = target_transform 711 | 712 | elif duration is not None: 713 | assert duration > 0 and isinstance(duration, int) 714 | frames_np_root = os.path.join(root, f'duration_{duration}') 715 | if os.path.exists(frames_np_root): 716 | print(f'The directory [{frames_np_root}] already exists.') 717 | 718 | else: 719 | os.mkdir(frames_np_root) 720 | print(f'Mkdir [{frames_np_root}].') 721 | # create the same directory structure 722 | create_same_directory_structure(events_np_root, frames_np_root) 723 | # use multi-thread to accelerate 724 | t_ckp = time.time() 725 | with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 64)) as tpe: 726 | print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].') 727 | for e_root, e_dirs, e_files in os.walk(events_np_root): 728 | if e_files.__len__() > 0: 729 
| output_dir = os.path.join(frames_np_root, os.path.relpath(e_root, events_np_root)) 730 | for e_file in e_files: 731 | events_np_file = os.path.join(e_root, e_file) 732 | print(f'Start to integrate [{events_np_file}] to frames and save to [{output_dir}].') 733 | tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration, events_np_file, output_dir, duration, H, W, True) 734 | 735 | print(f'Used time = [{round(time.time() - t_ckp, 2)}s].') 736 | 737 | _root = frames_np_root 738 | _loader = load_npz_frames 739 | _transform = transform 740 | _target_transform = target_transform 741 | 742 | else: 743 | raise ValueError('frames_number and duration can not both be None.') 744 | 745 | if train is not None: 746 | if train: 747 | _root = os.path.join(_root, 'train') 748 | else: 749 | _root = os.path.join(_root, 'test') 750 | 751 | super().__init__(root=_root, loader=_loader, extensions='.npz', transform=_transform, 752 | target_transform=_target_transform) 753 | 754 | @staticmethod 755 | @abstractmethod 756 | def load_origin_data(file_name: str) -> Dict: 757 | ''' 758 | :param file_name: path of the events file 759 | :type file_name: str 760 | :return: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray`` 761 | :rtype: Dict 762 | 763 | This function defines how to read the origin binary data. 764 | ''' 765 | pass 766 | 767 | @staticmethod 768 | @abstractmethod 769 | def resource_url_md5() -> list: 770 | ''' 771 | :return: A list ``url`` that ``url[i]`` is a tuple, which contains the i-th file's name, download link, and MD5 772 | :rtype: list 773 | ''' 774 | pass 775 | 776 | @staticmethod 777 | @abstractmethod 778 | def downloadable() -> bool: 779 | ''' 780 | :return: Whether the dataset can be directly downloaded by python codes. If not, the user have to download it manually 781 | :rtype: bool 782 | ''' 783 | pass 784 | 785 | @staticmethod 786 | @abstractmethod 787 | def extract_downloaded_files(download_root: str, extract_root: str): 788 | ''' 789 | :param download_root: Root directory path which saves downloaded dataset files 790 | :type download_root: str 791 | :param extract_root: Root directory path which saves extracted files from downloaded files 792 | :type extract_root: str 793 | :return: None 794 | 795 | This function defines how to extract download files. 796 | ''' 797 | pass 798 | 799 | @staticmethod 800 | @abstractmethod 801 | def create_events_np_files(extract_root: str, events_np_root: str): 802 | ''' 803 | :param extract_root: Root directory path which saves extracted files from downloaded files 804 | :type extract_root: str 805 | :param events_np_root: Root directory path which saves events files in the ``npz`` format 806 | :type events_np_root: 807 | :return: None 808 | 809 | This function defines how to convert the origin binary data in ``extract_root`` to ``npz`` format and save converted files in ``events_np_root``. 810 | ''' 811 | pass 812 | 813 | @staticmethod 814 | @abstractmethod 815 | def get_H_W() -> Tuple: 816 | ''' 817 | :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W` is the weight of the data. 818 | For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset. 
819 | :rtype: tuple 820 | ''' 821 | pass 822 | -------------------------------------------------------------------------------- /code/datasets/dvs128_gesture.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple 2 | import numpy as np 3 | from datasets.common import * 4 | from torchvision.datasets.utils import extract_archive 5 | import os 6 | import multiprocessing 7 | from concurrent.futures import ThreadPoolExecutor 8 | import time 9 | 10 | 11 | class DVS128Gesture(NeuromorphicDatasetFolder): 12 | def __init__( 13 | self, 14 | root: str, 15 | train: bool = None, 16 | data_type: str = 'event', 17 | frames_number: int = None, 18 | split_by: str = None, 19 | duration: int = None, 20 | transform: Optional[Callable] = None, 21 | target_transform: Optional[Callable] = None, 22 | ) -> None: 23 | ''' 24 | :param root: root path of the dataset 25 | :type root: str 26 | :param train: whether use the train set 27 | :type train: bool 28 | :param data_type: `event` or `frame` 29 | :type data_type: str 30 | :param frames_number: the integrated frame number 31 | :type frames_number: int 32 | :param split_by: `time` or `number` 33 | :type split_by: str 34 | :param duration: the time duration of each frame 35 | :type duration: int 36 | :param transform: a function/transform that takes in 37 | a sample and returns a transformed version. 38 | E.g, ``transforms.RandomCrop`` for images. 39 | :type transform: callable 40 | :param target_transform: a function/transform that takes 41 | in the target and transforms it. 42 | :type target_transform: callable 43 | 44 | If ``data_type == 'event'`` 45 | the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``. 46 | 47 | If ``data_type == 'frame'`` and ``frames_number`` is not ``None`` 48 | events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events. 49 | See :class:`cal_fixed_frames_number_segment_index` for 50 | more details. 51 | 52 | If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None`` 53 | events will be integrated to frames with fixed time duration. 54 | 55 | ''' 56 | assert train is not None 57 | super().__init__(root, train, data_type, frames_number, split_by, duration, transform, target_transform) 58 | 59 | @staticmethod 60 | def resource_url_md5() -> list: 61 | ''' 62 | :return: A list ``url`` that ``url[i]`` is a tuple, which contains the i-th file's name, download link, and MD5 63 | :rtype: list 64 | ''' 65 | url = 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794' 66 | return [ 67 | ('DvsGesture.tar.gz', url, '8a5c71fb11e24e5ca5b11866ca6c00a1'), 68 | ('gesture_mapping.csv', url, '109b2ae64a0e1f3ef535b18ad7367fd1'), 69 | ('LICENSE.txt', url, '065e10099753156f18f51941e6e44b66'), 70 | ('README.txt', url, 'a0663d3b1d8307c329a43d949ee32d19') 71 | ] 72 | 73 | @staticmethod 74 | def downloadable() -> bool: 75 | ''' 76 | :return: Whether the dataset can be directly downloaded by python codes. 
If not, the user has to download it manually
77 | :rtype: bool
78 | '''
79 | return False
80 | 
81 | @staticmethod
82 | def extract_downloaded_files(download_root: str, extract_root: str):
83 | '''
84 | :param download_root: Root directory path which saves downloaded dataset files
85 | :type download_root: str
86 | :param extract_root: Root directory path which saves extracted files from downloaded files
87 | :type extract_root: str
88 | :return: None
89 | 
90 | This function defines how to extract the downloaded files.
91 | '''
92 | fpath = os.path.join(download_root, 'DvsGesture.tar.gz')
93 | print(f'Extract [{fpath}] to [{extract_root}].')
94 | extract_archive(fpath, extract_root)
95 | 
96 | @staticmethod
97 | def load_origin_data(file_name: str) -> Dict:
98 | '''
99 | :param file_name: path of the events file
100 | :type file_name: str
101 | :return: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
102 | :rtype: Dict
103 | 
104 | This function defines how to read the original binary data.
105 | '''
106 | return load_aedat_v3(file_name)
107 | 
108 | @staticmethod
109 | def split_aedat_files_to_np(fname: str, aedat_file: str, csv_file: str, output_dir: str):
110 | events = DVS128Gesture.load_origin_data(aedat_file)
111 | print(f'Start to split [{aedat_file}] to samples.')
112 | # read the csv file to get the timestamp and label of each sample
113 | # then split the original data into samples
114 | csv_data = np.loadtxt(csv_file, dtype=np.uint32, delimiter=',', skiprows=1)
115 | 
116 | # Note that in some files, many samples share the same label, e.g., user26_fluorescent_labels.csv
117 | label_file_num = [0] * 11
118 | 
119 | # There are some wrong timestamps in this dataset, e.g., in user22_led_labels.csv, ``endTime_usec`` of class 9 is
120 | # larger than ``startTime_usec`` of class 10. So, the following code, which was used in an old version of SpikingJelly,
121 | # has been replaced by new code.
122 | 
123 | for i in range(csv_data.shape[0]):
124 | # the label of DVS128 Gesture is 1, 2, ..., 11. We set 0 as the first label, rather than 1
125 | label = csv_data[i][0] - 1
126 | t_start = csv_data[i][1]
127 | t_end = csv_data[i][2]
128 | mask = np.logical_and(events['t'] >= t_start, events['t'] < t_end)
129 | file_name = os.path.join(output_dir, str(label), f'{fname}_{label_file_num[label]}.npz')
130 | np.savez(file_name,
131 | t=events['t'][mask],
132 | x=events['x'][mask],
133 | y=events['y'][mask],
134 | p=events['p'][mask]
135 | )
136 | print(f'[{file_name}] saved.')
137 | label_file_num[label] += 1
138 | 
139 | # old codes:
140 | 
141 | # index = 0
142 | # index_l = 0
143 | # index_r = 0
144 | # for i in range(csv_data.shape[0]):
145 | # # the label of DVS128 Gesture is 1, 2, ..., 11.
We set 0 as the first label, rather than 1
146 | # label = csv_data[i][0] - 1
147 | # t_start = csv_data[i][1]
148 | # t_end = csv_data[i][2]
149 | # 
150 | # while True:
151 | # t = events['t'][index]
152 | # if t < t_start:
153 | # index += 1
154 | # else:
155 | # index_l = index
156 | # break
157 | # while True:
158 | # t = events['t'][index]
159 | # if t < t_end:
160 | # index += 1
161 | # else:
162 | # index_r = index
163 | # break
164 | # 
165 | # file_name = os.path.join(output_dir, str(label), f'{fname}_{label_file_num[label]}.npz')
166 | # np.savez(file_name,
167 | # t=events['t'][index_l:index_r],
168 | # x=events['x'][index_l:index_r],
169 | # y=events['y'][index_l:index_r],
170 | # p=events['p'][index_l:index_r]
171 | # )
172 | # print(f'[{file_name}] saved.')
173 | # label_file_num[label] += 1
174 | 
175 | @staticmethod
176 | def create_events_np_files(extract_root: str, events_np_root: str):
177 | '''
178 | :param extract_root: Root directory path which saves extracted files from downloaded files
179 | :type extract_root: str
180 | :param events_np_root: Root directory path which saves events files in the ``npz`` format
181 | :type events_np_root: str
182 | :return: None
183 | 
184 | This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
185 | '''
186 | aedat_dir = os.path.join(extract_root, 'DvsGesture')
187 | train_dir = os.path.join(events_np_root, 'train')
188 | test_dir = os.path.join(events_np_root, 'test')
189 | os.mkdir(train_dir)
190 | os.mkdir(test_dir)
191 | print(f'Mkdir [{train_dir}] and [{test_dir}].')
192 | for label in range(11):
193 | os.mkdir(os.path.join(train_dir, str(label)))
194 | os.mkdir(os.path.join(test_dir, str(label)))
195 | print(f'Mkdir {os.listdir(train_dir)} in [{train_dir}] and {os.listdir(test_dir)} in [{test_dir}].')
196 | 
197 | with open(os.path.join(aedat_dir, 'trials_to_train.txt')) as trials_to_train_txt, open(
198 | os.path.join(aedat_dir, 'trials_to_test.txt')) as trials_to_test_txt:
199 | # use multi-thread to accelerate
200 | t_ckp = time.time()
201 | with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 64)) as tpe:
202 | print(f'Start the ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
203 | 
204 | for fname in trials_to_train_txt.readlines():
205 | fname = fname.strip()
206 | if len(fname) > 0:
207 | aedat_file = os.path.join(aedat_dir, fname)
208 | fname = os.path.splitext(fname)[0]
209 | tpe.submit(DVS128Gesture.split_aedat_files_to_np, fname, aedat_file, os.path.join(aedat_dir, fname + '_labels.csv'), train_dir)
210 | 
211 | for fname in trials_to_test_txt.readlines():
212 | fname = fname.strip()
213 | if len(fname) > 0:
214 | aedat_file = os.path.join(aedat_dir, fname)
215 | fname = os.path.splitext(fname)[0]
216 | tpe.submit(DVS128Gesture.split_aedat_files_to_np, fname, aedat_file,
217 | os.path.join(aedat_dir, fname + '_labels.csv'), test_dir)
218 | 
219 | print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
220 | print(f'All aedat files have been split to samples and saved into [{train_dir}] and [{test_dir}].')
221 | 
222 | @staticmethod
223 | def get_H_W() -> Tuple:
224 | '''
225 | :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
226 | For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
227 | :rtype: tuple
228 | '''
229 | return 128, 128
230 | 
--------------------------------------------------------------------------------
/code/datasets/n_mnist.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple
2 | import numpy as np
3 | from datasets.common import *
4 | from torchvision.datasets.utils import extract_archive
5 | import os
6 | import multiprocessing
7 | from concurrent.futures import ThreadPoolExecutor
8 | import time
9 | 
10 | 
11 | class NMNIST(NeuromorphicDatasetFolder):
12 | def __init__(
13 | self,
14 | root: str,
15 | train: bool = None,
16 | data_type: str = 'event',
17 | frames_number: int = None,
18 | split_by: str = None,
19 | duration: int = None,
20 | transform: Optional[Callable] = None,
21 | target_transform: Optional[Callable] = None,
22 | ) -> None:
23 | '''
24 | :param root: root path of the dataset
25 | :type root: str
26 | :param train: whether to use the train set
27 | :type train: bool
28 | :param data_type: `event` or `frame`
29 | :type data_type: str
30 | :param frames_number: the integrated frame number
31 | :type frames_number: int
32 | :param split_by: `time` or `number`
33 | :type split_by: str
34 | :param duration: the time duration of each frame
35 | :type duration: int
36 | :param transform: a function/transform that takes in
37 | a sample and returns a transformed version.
38 | E.g., ``transforms.RandomCrop`` for images.
39 | :type transform: callable
40 | :param target_transform: a function/transform that takes
41 | in the target and transforms it.
42 | :type target_transform: callable
43 | 
44 | If ``data_type == 'event'``
45 | the sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``.
46 | 
47 | If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
48 | events will be integrated into frames with a fixed number of frames. ``split_by`` will define how to split the events.
49 | See :class:`cal_fixed_frames_number_segment_index` for
50 | more details.
51 | 
52 | If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
53 | events will be integrated into frames with a fixed time duration.
54 | 
55 | '''
56 | assert train is not None
57 | super().__init__(root, train, data_type, frames_number, split_by, duration, transform, target_transform)
58 | @staticmethod
59 | def resource_url_md5() -> list:
60 | '''
61 | :return: A list ``url`` where ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5 checksum
62 | :rtype: list
63 | '''
64 | url = 'https://www.garrickorchard.com/datasets/n-mnist'
65 | return [
66 | ('Train.zip', url, '20959b8e626244a1b502305a9e6e2031'),
67 | ('Test.zip', url, '69ca8762b2fe404d9b9bad1103e97832')
68 | ]
69 | 
70 | @staticmethod
71 | def downloadable() -> bool:
72 | '''
73 | :return: Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
74 | :rtype: bool
75 | '''
76 | return False
77 | 
78 | @staticmethod
79 | def extract_downloaded_files(download_root: str, extract_root: str):
80 | '''
81 | :param download_root: Root directory path which saves downloaded dataset files
82 | :type download_root: str
83 | :param extract_root: Root directory path which saves extracted files from downloaded files
84 | :type extract_root: str
85 | :return: None
86 | 
87 | This function defines how to extract the downloaded files.
88 | '''
89 | with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
90 | for zip_file in os.listdir(download_root):
91 | zip_file = os.path.join(download_root, zip_file)
92 | print(f'Extract [{zip_file}] to [{extract_root}].')
93 | tpe.submit(extract_archive, zip_file, extract_root)
94 | 
95 | 
96 | @staticmethod
97 | def load_origin_data(file_name: str) -> Dict:
98 | '''
99 | :param file_name: path of the events file
100 | :type file_name: str
101 | :return: a dict whose keys are ['t', 'x', 'y', 'p'] and values are ``numpy.ndarray``
102 | :rtype: Dict
103 | 
104 | This function defines how to read the original binary data.
105 | '''
106 | 
107 | return load_ATIS_bin(file_name)
108 | 
109 | @staticmethod
110 | def get_H_W() -> Tuple:
111 | '''
112 | :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
113 | For example, this function returns ``(34, 34)`` for the N-MNIST dataset.
114 | :rtype: tuple
115 | '''
116 | return 34, 34
117 | 
118 | @staticmethod
119 | def read_bin_save_to_np(bin_file: str, np_file: str):
120 | events = NMNIST.load_origin_data(bin_file)
121 | np.savez(np_file,
122 | t=events['t'],
123 | x=events['x'],
124 | y=events['y'],
125 | p=events['p']
126 | )
127 | print(f'Save [{bin_file}] to [{np_file}].')
128 | 
129 | 
130 | @staticmethod
131 | def create_events_np_files(extract_root: str, events_np_root: str):
132 | '''
133 | :param extract_root: Root directory path which saves extracted files from downloaded files
134 | :type extract_root: str
135 | :param events_np_root: Root directory path which saves events files in the ``npz`` format
136 | :type events_np_root: str
137 | :return: None
138 | 
139 | This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
140 | ''' 141 | t_ckp = time.time() 142 | with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 8)) as tpe: 143 | # too many threads will make the disk overload 144 | for train_test_dir in ['Train', 'Test']: 145 | source_dir = os.path.join(extract_root, train_test_dir) 146 | target_dir = os.path.join(events_np_root, train_test_dir.lower()) 147 | os.mkdir(target_dir) 148 | print(f'Mkdir [{target_dir}].') 149 | for class_name in os.listdir(source_dir): 150 | bin_dir = os.path.join(source_dir, class_name) 151 | np_dir = os.path.join(target_dir, class_name) 152 | os.mkdir(np_dir) 153 | print(f'Mkdir [{np_dir}].') 154 | for bin_file in os.listdir(bin_dir): 155 | source_file = os.path.join(bin_dir, bin_file) 156 | target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz') 157 | print(f'Start to convert [{source_file}] to [{target_file}].') 158 | tpe.submit(NMNIST.read_bin_save_to_np, source_file, 159 | target_file) 160 | 161 | 162 | print(f'Used time = [{round(time.time() - t_ckp, 2)}s].') 163 | -------------------------------------------------------------------------------- /code/encoders.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch import einsum 7 | from einops import rearrange, repeat 8 | 9 | from common import ActFun 10 | 11 | 12 | class AutoEncoder(nn.Module): 13 | def __init__(self, step, spike_output=True): 14 | super(AutoEncoder, self).__init__() 15 | self.step = step 16 | self.spike_output = spike_output 17 | 18 | # self.gru = nn.GRU(input_size=1, hidden_size=1, num_layers=3) 19 | self.sigmoid = nn.Sigmoid() 20 | self.fc1 = nn.Linear(1, self.step) 21 | self.fc2 = nn.Linear(self.step, self.step) 22 | self.relu = nn.ReLU() 23 | # 24 | self.act_fun = ActFun.apply 25 | 26 | def forward(self, x): 27 | shape = x.shape 28 | 29 | x = self.fc1(x.view(-1, 1)) 30 | x = self.relu(x) 31 | x = self.fc2(x).transpose_(1, 0) 32 | 33 | # x = x.view(1, -1, 1).repeat(self.step, 1, 1) 34 | # x, _ = self.gru(x) 35 | 36 | x = self.sigmoid(x) 37 | if not self.spike_output: 38 | return x.view(self.step, *shape) 39 | else: 40 | return self.act_fun(x).view(self.step, *shape) 41 | 42 | 43 | # class TransEncoder(nn.Module): 44 | # def __init__(self, step): 45 | # super(TransEncoder, self).__init__() 46 | # self.step = step 47 | # self.trans = Transformer(dim=128, depth=3, heads=8, dim_head=, mlp_dim, dropout=0.) 
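# Illustrative use of AutoEncoder above (a hypothetical snippet, not part of the original repo):
# it learns a mapping from each scalar intensity to a length-`step` temporal code, so a static
# batch becomes a (step, *input_shape) tensor, spiking (via ActFun) when spike_output=True.
#
#     enc = AutoEncoder(step=8, spike_output=True)
#     x = torch.rand(4, 1, 28, 28)   # intensities in [0, 1]
#     out = enc(x)                   # -> shape (8, 4, 1, 28, 28)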
48 | 
49 | 
50 | class Encoder(nn.Module):
51 | '''
52 | Encodes a static input of shape (batch_size, ...) into a spike tensor of shape (step, batch_size, ...).
53 | '''
54 | def __init__(self, step, encode_type='ttfs'):
55 | super(Encoder, self).__init__()
56 | self.step = step
57 | self.fun = getattr(self, encode_type)
58 | self.encode_type = encode_type
59 | if encode_type == 'auto':
60 | self.fun = AutoEncoder(self.step, spike_output=False)
61 | 
62 | def forward(self, inputs, deletion_prob=None, shift_var=None):
63 | if self.encode_type == 'auto':
64 | if next(self.fun.parameters()).device != inputs.device:
65 | self.fun.to(inputs.device)
66 | 
67 | outputs = self.fun(inputs)
68 | if deletion_prob:
69 | outputs = self.delete(outputs, deletion_prob)
70 | if shift_var:
71 | outputs = self.shift(outputs, shift_var)
72 | return outputs
73 | 
74 | @torch.no_grad()
75 | def direct(self, inputs):
76 | shape = inputs.shape
77 | outputs = inputs.unsqueeze(0).repeat(self.step, *([1] * len(shape)))
78 | return outputs
79 | 
80 | def auto(self, inputs):
81 | shape = inputs.shape
82 | outputs = self.fun(inputs)
83 | # print(outputs.shape)
84 | return outputs
85 | 
86 | @torch.no_grad()
87 | def ttfs(self, inputs):
88 | # time-to-first-spike: larger intensities fire earlier
89 | shape = (self.step,) + inputs.shape
90 | outputs = torch.zeros(shape, device=inputs.device)
91 | for i in range(self.step):
92 | mask = (inputs * self.step <= (self.step - i)) & (inputs * self.step > (self.step - i - 1))
93 | outputs[i, mask] = 1 / (i + 1)
94 | return outputs
95 | 
96 | @torch.no_grad()
97 | def rate(self, inputs):
98 | shape = (self.step,) + inputs.shape
99 | return (inputs > torch.rand(shape, device=inputs.device)).float()
100 | 
101 | @torch.no_grad()
102 | def phase(self, inputs):
103 | shape = (self.step,) + inputs.shape
104 | outputs = torch.zeros(shape, device=inputs.device)
105 | inputs = (inputs * 256).long()
106 | val = 1.
107 | for i in range(self.step):
108 | if i < 8:
109 | mask = (inputs >> (8 - i - 1)) & 1 != 0
110 | outputs[i, mask] = val
111 | val /= 2.
112 | else:
113 | outputs[i] = outputs[i % 8]
114 | return outputs
115 | 
116 | @torch.no_grad()
117 | def delete(self, inputs, prob):
118 | mask = (inputs >= 0) & (torch.rand_like(inputs) < prob)  # rand_like gives the intended per-element deletion probability
119 | inputs[mask] = 0.
120 | return inputs
121 | 
122 | @torch.no_grad()
123 | def shift(self, inputs, var):
124 | outputs = torch.zeros_like(inputs)
125 | for step in range(self.step):
126 | shift = (var * torch.randn(1)).round_() + step
127 | shift.clamp_(min=0, max=self.step - 1)
128 | outputs[step] += inputs[int(shift)]
129 | return outputs
130 | 
--------------------------------------------------------------------------------
/code/losses.py:
--------------------------------------------------------------------------------
1 | # Thanks to rwightman's timm package
2 | # github.com:rwightman/pytorch-image-models
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class LabelSmoothingBCEWithLogitsLoss(nn.Module):
10 | 
11 | def __init__(self, smoothing=0.1):
12 | """
13 | Constructor for the LabelSmoothing module.
14 | :param smoothing: label smoothing factor
15 | """
16 | super(LabelSmoothingBCEWithLogitsLoss, self).__init__()
17 | assert smoothing < 1.0
18 | self.smoothing = smoothing
19 | self.confidence = 1. - smoothing
20 | self.BCELoss = nn.BCEWithLogitsLoss()
21 | 
22 | def forward(self, x, target):
23 | target = torch.eye(x.shape[-1], device=x.device)[target]
24 | nll = torch.ones_like(x) / x.shape[-1]
25 | return self.BCELoss(x, target) * self.confidence + self.BCELoss(x, nll) * self.smoothing
26 | 
27 | 
28 | class LabelSmoothingCrossEntropy(nn.Module):
29 | """
30 | NLL loss with label smoothing.
31 | """
32 | 
33 | def __init__(self, smoothing=0.1):
34 | """
35 | Constructor for the LabelSmoothing module.
36 | :param smoothing: label smoothing factor
37 | """
38 | super(LabelSmoothingCrossEntropy, self).__init__()
39 | assert smoothing < 1.0
40 | self.smoothing = smoothing
41 | self.confidence = 1. - smoothing
42 | 
43 | def _compute_losses(self, x, target):
44 | log_prob = F.log_softmax(x, dim=-1)
45 | nll_loss = -log_prob.gather(dim=-1, index=target.unsqueeze(1))
46 | nll_loss = nll_loss.squeeze(1)
47 | smooth_loss = -log_prob.mean(dim=-1)
48 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
49 | return loss
50 | 
51 | def forward(self, x, target):
52 | return self._compute_losses(x, target).mean()
53 | 
54 | 
55 | class SoftCrossEntropy(torch.nn.Module):
56 | def __init__(self):
57 | super(SoftCrossEntropy, self).__init__()
58 | 
59 | def forward(self, inputs, targets, temperature=1.):
60 | log_likelihood = -F.log_softmax(inputs / temperature, dim=1)
61 | likelihood = F.softmax(targets / temperature, dim=1)
62 | sample_num, class_num = targets.shape
63 | loss = torch.sum(torch.mul(log_likelihood, likelihood)) / sample_num
64 | return loss
65 | 
66 | 
67 | class UnilateralMse(torch.nn.Module):
68 | def __init__(self, thresh):
69 | super(UnilateralMse, self).__init__()
70 | self.thresh = thresh
71 | self.loss = torch.nn.MSELoss()
72 | 
73 | def forward(self, x, target):
74 | # x = nn.functional.softmax(x, dim=1)
75 | x = torch.clip(x, max=self.thresh)  # torch.clip is not in-place, so the result must be assigned back
76 | # print(x)
77 | return self.loss(x, torch.zeros(x.shape, device=x.device).scatter_(1, target.view(-1, 1), self.thresh))
78 | 
79 | 
80 | class WarmUpLoss(torch.nn.Module):
81 | def __init__(self):
82 | super(WarmUpLoss, self).__init__()
83 | self.ce = torch.nn.CrossEntropyLoss()
84 | 
85 | def forward(self, x, target, epoch=15):
86 | x = nn.functional.softmax(x, dim=1)
87 | return self.ce(x, target)
88 | 
89 | 
90 | 
91 | 
--------------------------------------------------------------------------------
/code/networks.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | from encoders import Encoder
4 | from operations import *
5 | from common import *
6 | from nodes import *
7 | 
8 | 
9 | class ConvNet(nn.Module):
10 | def __init__(self,
11 | step,
12 | dataset,
13 | num_classes,
14 | encode_type,
15 | node,
16 | *args,
17 | **kwargs):
18 | super(ConvNet, self).__init__()
19 | self.step = step
20 | self.dataset = dataset
21 | self.num_classes = num_classes
22 | self.node = eval(node) if type(node) == str else node
23 | self.warm_up = False
24 | 
25 | if 'threshold' in kwargs.keys():
26 | self.threshold = kwargs['threshold']
27 | else:
28 | self.threshold = .5
29 | if 'decay' in kwargs.keys():
30 | self.decay = kwargs['decay']
31 | else:
32 | self.decay = 1.
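# threshold and decay are forwarded to every spiking node built below; with these defaults a
# node fires once its membrane potential reaches 0.5 and keeps its residual potential from
# step to step (decay = 1 means no leak).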
33 | 
34 | 
35 | self.encoder = Encoder(self.step, encode_type)
36 | if dataset == 'mnist' or dataset == 'fashion':
37 | self.fun = nn.ModuleList([
38 | nn.Conv2d(1, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
39 | nn.BatchNorm2d(128),
40 | nn.ReLU(),
41 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
42 | nn.MaxPool2d(2),
43 | 
44 | nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
45 | nn.BatchNorm2d(128),
46 | nn.ReLU(),
47 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
48 | nn.Conv2d(128, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
49 | nn.BatchNorm2d(256),
50 | nn.ReLU(),
51 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
52 | nn.AvgPool2d(2),
53 | 
54 | nn.Flatten(),
55 | nn.Linear(7 * 7 * 256, 2048),
56 | nn.ReLU(),
57 | NDropout(.5),
58 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
59 | nn.Linear(2048, 10 * num_classes),
60 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
61 | VotingLayer(10)
62 | ])
63 | 
64 | elif dataset == 'dvsg' or dataset == 'dvsc10':
65 | self.fun = nn.ModuleList([
66 | nn.Conv2d(2, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
67 | nn.BatchNorm2d(128),
68 | nn.ReLU(),
69 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
70 | nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
71 | nn.BatchNorm2d(128),
72 | nn.ReLU(),
73 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
74 | nn.MaxPool2d(2),
75 | 
76 | nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
77 | nn.BatchNorm2d(128),
78 | nn.ReLU(),
79 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
80 | nn.MaxPool2d(2),
81 | 
82 | nn.Conv2d(128, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
83 | nn.BatchNorm2d(256),
84 | nn.ReLU(),
85 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
86 | nn.MaxPool2d(2),
87 | 
88 | nn.Conv2d(256, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
89 | nn.BatchNorm2d(512),
90 | nn.ReLU(),
91 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
92 | nn.AvgPool2d(2),
93 | 
94 | nn.Conv2d(512, 1024, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
95 | nn.BatchNorm2d(1024),
96 | nn.ReLU(),
97 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
98 | nn.AvgPool2d(4),
99 | 
100 | nn.Flatten(),
101 | NDropout(.5),
102 | nn.Linear(1024, 10 * num_classes),
103 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
104 | VotingLayer(10),
105 | ])
106 | 
107 | elif dataset == 'cifar10' or dataset == 'cifar100':
108 | self.fun = nn.ModuleList([
109 | nn.Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
110 | nn.BatchNorm2d(128),
111 | nn.ReLU(),
112 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
113 | nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)),
114 | nn.BatchNorm2d(128),
115 | nn.ReLU(),
116 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(),
117 | nn.MaxPool2d(2),
118 | 
119 | # 
nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 120 | # nn.BatchNorm2d(128), 121 | # nn.ReLU(), 122 | # self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 123 | nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 124 | nn.BatchNorm2d(128), 125 | nn.ReLU(), 126 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 127 | nn.MaxPool2d(2), 128 | 129 | nn.Conv2d(128, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 130 | nn.BatchNorm2d(256), 131 | nn.ReLU(), 132 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 133 | # nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 134 | # nn.BatchNorm2d(256), 135 | # nn.ReLU(), 136 | # self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 137 | nn.MaxPool2d(2), 138 | 139 | nn.Conv2d(256, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 140 | nn.BatchNorm2d(512), 141 | nn.ReLU(), 142 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 143 | # nn.Conv2d(512, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 144 | # nn.BatchNorm2d(512), 145 | # nn.ReLU(), 146 | # self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 147 | nn.AvgPool2d(4), 148 | 149 | nn.Flatten(), 150 | nn.Linear(512, 10 * num_classes), 151 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 152 | VotingLayer(10), 153 | ]) 154 | 155 | elif dataset == 'nmnist': 156 | self.fun = nn.ModuleList([ 157 | nn.Conv2d(2, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 158 | nn.BatchNorm2d(128), 159 | nn.ReLU(), 160 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 161 | nn.MaxPool2d(2), 162 | 163 | nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 164 | nn.BatchNorm2d(128), 165 | nn.ReLU(), 166 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 167 | nn.MaxPool2d(2), 168 | 169 | nn.Conv2d(128, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 170 | nn.BatchNorm2d(256), 171 | nn.ReLU(), 172 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 173 | nn.AvgPool2d(2), 174 | 175 | nn.Conv2d(256, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)), 176 | nn.BatchNorm2d(512), 177 | nn.ReLU(), 178 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 179 | nn.AvgPool2d(4), 180 | 181 | nn.Flatten(), 182 | NDropout(.5), 183 | nn.Linear(512, 10 * num_classes), 184 | self.node(threshold=self.threshold, decay=self.decay) if self.step != 1 else nn.Identity(), 185 | VotingLayer(10), 186 | ]) 187 | 188 | self.fire_rate = [] 189 | 190 | def forward(self, inputs): 191 | step = self.step if self.warm_up is False or len(inputs.shape) != 4 else 1 192 | 193 | if len(inputs.shape) == 4: 194 | inputs = self.encoder(inputs) 195 | else: 196 | inputs = inputs.permute(1, 0, 2, 3, 4) 197 | # print(inputs.float()) 198 | self.reset() 199 | 200 | if not self.training: 201 | self.fire_rate.clear() 202 | 203 | outputs = [] 204 | for t in range(step): 205 | x = inputs[t] 206 | if self.dataset == 'dvsg' or self.dataset == 'dvsc10': 207 | x = F.interpolate(x, size=[64, 64]) 208 | for layer in self.fun: 209 | if type(layer) == self.node and self.warm_up: 210 | 
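# During warm-up the spiking nodes are skipped entirely, so the stack behaves like a
# plain ANN (Conv/BN/ReLU only); set_warm_up below toggles this flag network-wide.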
continue
211 | 
212 | x = layer(x)
213 | 
214 | # if hasattr(layer, 'integral'):
215 | # print(((x > 0.).float()).sum() / np.prod(x.shape))
216 | outputs.append(x)
217 | 
218 | if not self.training:
219 | self.fire_rate.append(self._get_fire_rate())
220 | 
221 | return sum(outputs) / len(outputs)
222 | 
223 | def reset(self):
224 | for layer in self.fun:
225 | if hasattr(layer, 'n_reset'):
226 | layer.n_reset()
227 | 
228 | def set_ltd(self, value):
229 | for layer in self.fun:
230 | if hasattr(layer, 'ltd'):
231 | layer.ltd = value
232 | 
233 | def get_mem_loss(self):
234 | raise NotImplementedError
235 | 
236 | def _get_fire_rate(self):
237 | outputs = []
238 | for layer in self.fun:
239 | if hasattr(layer, 'get_fire_rate'):
240 | outputs.append(layer.get_fire_rate())
241 | return outputs
242 | 
243 | def get_fire_rate(self):
244 | x = np.array(self.fire_rate)
245 | x = x.mean(axis=0)
246 | return x.tolist()
247 | 
248 | def get_threshold(self):
249 | outputs = []
250 | for layer in self.fun:
251 | if hasattr(layer, 'threshold'):
252 | thresh = nn.Sigmoid()(layer.threshold.detach().clone())
253 | outputs.append(thresh)
254 | return outputs
255 | 
256 | def get_decay(self):
257 | outputs = []
258 | for layer in self.fun:
259 | if hasattr(layer, 'decay'):
260 | # outputs.append(float(torch.nn.Sigmoid()(layer.decay.detach().clone())))
261 | outputs.append(float(layer.decay.detach().clone()))
262 | return outputs
263 | 
264 | def set_warm_up(self, flag):
265 | self.warm_up = flag
266 | for mod in self.modules():
267 | if hasattr(mod, 'set_n_warm_up'):
268 | mod.set_n_warm_up(flag)
269 | 
--------------------------------------------------------------------------------
/code/nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | from torch.nn import Parameter
8 | from common import ActFun
9 | 
10 | 
11 | class DGLIFNode(nn.Module):
12 | def __init__(self,
13 | threshold=.5,
14 | shape=None,
15 | decay=1.):
16 | super(DGLIFNode, self).__init__()
17 | self.shape = shape
18 | 
19 | self.mem = None
20 | self.spike = None
21 | self.threshold = Parameter(torch.tensor(threshold), requires_grad=False)
22 | self.decay = Parameter(torch.tensor(decay), requires_grad=False)
23 | self.n_reset()
24 | 
25 | def n_reset(self):
26 | self.mem = None #torch.zeros(self.shape, device=self.device)
27 | self.spike = None #torch.zeros(self.shape, device=self.device)
28 | 
29 | def integral(self, inputs):
30 | if self.mem is None:
31 | self.mem = inputs
32 | else:
33 | self.mem += inputs
34 | 
35 | def calc_spike(self):
36 | # thresh = nn.Sigmoid()(self.threshold)
37 | # self.spike = self.mem.clone() / thresh
38 | # self.spike[self.spike < 1.] = 0.
39 | # self.spike = thresh.detach() * self.spike / (self.mem.detach().clone() + 1e-12)
40 | # self.mem[self.mem >= thresh] = 0.
41 | if True: #self.training:
42 | self.spike = self.mem.clone()
43 | self.spike[(self.spike < self.threshold)] = 0.
44 | self.spike = self.spike / (self.mem.detach().clone() + 1e-12) #* self.threshold
45 | # print(self.spike)
46 | self.mem[(self.mem >= self.threshold)] = 0.
47 | 
48 | self.mem = self.mem * self.decay
49 | 
50 | def forward(self, inputs):
51 | self.integral(inputs)
52 | self.calc_spike()
53 | return self.spike
54 | 
55 | def get_fire_rate(self):
56 | return float((self.spike.detach().sum())) / float(np.prod(self.spike.shape))
57 | 
58 | def get_mem_loss(self):
59 | spike = self.spike[self.spike > 0.]
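# Mean-squared penalty pulling fired spike amplitudes toward 1; entries that did not
# fire are masked out above, so only active positions contribute to this loss.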
60 | return (spike - 1.) ** 2
61 | 
62 | 
63 | class HTGLIFNode(nn.Module):
64 | def __init__(self,
65 | threshold=.5,
66 | shape=None,
67 | decay=1.):
68 | super(HTGLIFNode, self).__init__()
69 | self.shape = shape
70 | 
71 | self.mem = None
72 | self.spike = None
73 | self.threshold = Parameter(torch.tensor(threshold), requires_grad=False)
74 | self.decay = Parameter(torch.tensor(decay), requires_grad=False)
75 | 
76 | self.warm_up = False
77 | 
78 | self.n_reset()
79 | 
80 | def n_reset(self):
81 | self.mem = None # torch.zeros(self.shape, device=self.device)
82 | self.spike = None # torch.zeros(self.shape, device=self.device)
83 | 
84 | def integral(self, inputs):
85 | if self.mem is None:
86 | self.mem = inputs
87 | else:
88 | self.mem += inputs
89 | 
90 | def calc_spike(self):
91 | spike = self.mem.clone()
92 | spike[(spike < self.threshold)] = 0.
93 | self.spike = spike / (self.mem.detach().clone() + 1e-12)
94 | self.mem = torch.where(self.mem >= self.threshold, torch.zeros_like(self.mem), self.mem)
95 | 
96 | self.mem = self.mem + 0.2 * self.spike - 0.2 * self.spike.detach()  # straight-through trick: the forward value is unchanged, but a gradient path through the spike is added
97 | self.mem = self.mem * self.decay
98 | 
99 | def forward(self, inputs):
100 | if self.warm_up:
101 | return inputs
102 | else:
103 | self.integral(inputs)
104 | self.calc_spike()
105 | return self.spike
106 | 
107 | def get_fire_rate(self):
108 | if self.spike is None:
109 | return 0.
110 | return float((self.spike.detach() >= self.threshold).sum()) / float(np.prod(self.spike.shape))
111 | 
112 | def get_mem_loss(self):
113 | spike = self.spike[self.spike > 0.]
114 | return (spike - 1.) ** 2
115 | 
116 | def set_n_warm_up(self, flag):
117 | self.warm_up = flag
118 | 
119 | def set_n_threshold(self, thresh):
120 | self.threshold = Parameter(torch.tensor(thresh, dtype=torch.float), requires_grad=False)
121 | 
122 | 
123 | class SGLIFNode(nn.Module):
124 | def __init__(self,
125 | shape=None,
126 | threshold=None,
127 | decay=1.):
128 | super(SGLIFNode, self).__init__()
129 | self.shape = shape
130 | self.act_fun = ActFun.apply
131 | 
132 | self.decay = Parameter(torch.tensor(decay), requires_grad=False)
133 | self.mem = None
134 | self.spike = None
135 | self.n_reset()
136 | 
137 | def n_reset(self):
138 | self.mem = None
139 | self.spike = None
140 | 
141 | def integral(self, inputs):
142 | if self.mem is None:
143 | self.mem = inputs
144 | else:
145 | self.mem += inputs
146 | 
147 | def calc_spike(self):
148 | self.spike = self.act_fun(self.mem)
149 | self.mem = self.mem * self.decay * (1. - self.spike)
150 | # self.spike = self.mem.clone()
151 | # self.spike[(self.spike < 1.) & (self.spike > -1.)] = 0.
152 | # self.mem = self.mem * self.decay
153 | # self.mem[(self.mem >= 1.) | (self.mem <= -1.)] = 0.
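# Soft reset: multiplying by (1 - spike) zeroes the fired positions while the rest decay;
# the commented block above is the older hard-reset variant, kept for reference.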
154 | return self.spike
155 | 
156 | def forward(self, inputs):
157 | self.integral(inputs)
158 | self.calc_spike()
159 | return self.spike
160 | 
161 | def get_fire_rate(self):
162 | return float((self.spike.detach().sum())) / float(np.prod(self.spike.shape))
163 | 
--------------------------------------------------------------------------------
/code/operations.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | from torch import einsum
8 | import torch.nn.functional as F
9 | from einops import rearrange, repeat
10 | from einops.layers.torch import Rearrange
11 | 
12 | from nodes import *
13 | 
14 | 
15 | class VotingLayer(nn.Module):
16 | def __init__(self, voter_num: int):
17 | super().__init__()
18 | self.voting = nn.AvgPool1d(voter_num, voter_num)
19 | 
20 | def forward(self, x: torch.Tensor):
21 | # x.shape = [N, voter_num * C]
22 | # ret.shape = [N, C]
23 | return self.voting(x.unsqueeze(1)).squeeze(1)
24 | 
25 | 
26 | class NDropout(nn.Module):
27 | def __init__(self, p):
28 | super(NDropout, self).__init__()
29 | self.p = p
30 | self.mask = None
31 | 
32 | def n_reset(self):
33 | self.mask = None
34 | 
35 | def create_mask(self, x):
36 | self.mask = F.dropout(torch.ones_like(x.data), self.p, training=True)
37 | 
38 | def forward(self, x):
39 | if self.training:
40 | if self.mask is None:
41 | self.create_mask(x)
42 | 
43 | return self.mask * x
44 | else:
45 | return x
46 | 
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import argparse
5 | import logging
6 | 
7 | from torch import optim
8 | from torch.cuda.amp import autocast
9 | from torch.cuda.amp import GradScaler
10 | from thop import profile, clever_format
11 | 
12 | from networks import *
13 | from losses import LabelSmoothingCrossEntropy, UnilateralMse, WarmUpLoss, LabelSmoothingBCEWithLogitsLoss
14 | from common import *
15 | from datasets import get_dvsg_data, get_dvsc10_data, get_nmnist_data
16 | 
17 | import warnings
18 | warnings.filterwarnings("ignore")
19 | 
20 | scaler = GradScaler()
21 | 
22 | torch.backends.cudnn.benchmark = True
23 | torch.set_num_threads(64)
24 | 
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('model', type=str)
27 | parser.add_argument('--epochs', type=int, default=300)
28 | 
29 | parser.add_argument('--dataset', type=str, default='cifar10')
30 | parser.add_argument('--batch-size', type=int, default=128)
31 | parser.add_argument('--device', type=int, default=-1)
32 | 
33 | parser.add_argument('--step', type=int, default=16)
34 | parser.add_argument('--encode', type=str, default='direct')
35 | parser.add_argument('--node', type=str, default='HTGLIFNode')
36 | parser.add_argument('--thresh', type=float, default=.5)
37 | parser.add_argument('--decay', type=float, default=.9)
38 | 
39 | parser.add_argument('--suffix', type=str, default='')
40 | parser.add_argument('--infer_every', type=int, default=1)
41 | parser.add_argument('--warmup', type=int, default=5)
42 | parser.add_argument('--optim', type=str, default='AdamW')
43 | parser.add_argument('--loss', type=str, default='mse')
44 | 
45 | parser.add_argument('--lr', type=float, default=0.001)
46 | parser.add_argument('--momentum', type=float, default=0.9)
47 | parser.add_argument('--weight-decay', type=float, default=1e-4)
48 | # 
parser.add_argument('--grad-clip', type=float, default=5.)
49 | parser.add_argument('--lr_min', type=float, default=1e-5)
50 | parser.add_argument('--resume', type=str, default='')
51 | parser.add_argument('--base_dir', type=str, default='/path/to/checkpoint')
52 | parser.add_argument('--warm-up', type=int, default=0)
53 | parser.add_argument('--save', type=bool, default=False)  # note: with type=bool, any non-empty string (even 'False') parses as True
54 | args = parser.parse_args()
55 | 
56 | print(args)
57 | # DEBUG: uncomment to locate NaN/Inf in the backward pass (anomaly detection is very slow)
58 | # torch.autograd.set_detect_anomaly(True)
59 | 
60 | 
61 | CKPT_DIR = args.base_dir
62 | # CKPT_DIR = './ckpt'
63 | fname = '%s_%s_%s_%s_%d_%s_%s_%s_%.2f.pt' % (args.model, args.loss, args.suffix, args.dataset, args.step, args.encode,
64 | 'finetune' if args.resume else '', args.node, args.thresh)
65 | flog = fname + '.txt'
66 | # log = open(os.path.join(CKPT_DIR, flog), 'w')
67 | 
68 | print(fname)
69 | 
70 | device = torch.device("cuda:%d" % args.device if args.device >= 0 else "cpu")
71 | best_acc = 0.
72 | 
73 | num_classes = 10
74 | 
75 | 
76 | if args.dataset == 'dvsg':
77 | train_loader, test_loader, _, _ = get_dvsg_data(args.batch_size, args.step)
78 | num_classes = 11
79 | elif args.dataset == 'dvsc10':
80 | train_loader, test_loader, _, _ = get_dvsc10_data(args.batch_size, args.step)
81 | elif args.dataset == 'nmnist':
82 | train_loader, test_loader, _, _ = get_nmnist_data(args.batch_size, args.step)
83 | else:
84 | raise NotImplementedError
85 | 
86 | model = eval(args.model)(step=args.step,
87 | dataset=args.dataset,
88 | batch_size=args.batch_size,
89 | num_classes=num_classes,
90 | device=device,
91 | encode_type=args.encode,
92 | node=args.node,
93 | threshold=args.thresh,
94 | decay=args.decay).to(device)
95 | 
96 | # optimizer = optim.SGD(
97 | # [{'params': [param for name, param in model.named_parameters() if 'floyed' not in name]}],
98 | # lr=args.lr,
99 | # weight_decay=args.weight_decay
100 | # )
101 | 
102 | if args.optim == 'AdamW':
103 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
104 | elif args.optim == 'SGD':
105 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
106 | else:
107 | raise NotImplementedError
108 | # 
109 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(
110 | # optimizer, T_max=float(args.epochs), eta_min=args.lr_min, last_epoch=-1
111 | # )
112 | 
113 | # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 90], gamma=0.1)
114 | 
115 | if args.loss == 'ce':
116 | criterion = nn.CrossEntropyLoss()
117 | elif args.loss == 'bce':
118 | criterion = nn.BCEWithLogitsLoss()
119 | elif args.loss == 'mse':
120 | criterion = UnilateralMse(1.)
121 | elif args.loss == 'sce':
122 | criterion = LabelSmoothingCrossEntropy()
123 | elif args.loss == 'sbce':
124 | criterion = LabelSmoothingBCEWithLogitsLoss()
125 | elif args.loss == 'umse':
126 | criterion = UnilateralMse(.5)
127 | elif args.loss == 'mixed':
128 | criterion = WarmUpLoss()
129 | else:
130 | raise NotImplementedError
131 | 
132 | epoch_start = 0
133 | 
134 | 
135 | def train():
136 | if args.warm_up != 0:
137 | model.set_warm_up(True)
138 | # criterion = nn.CrossEntropyLoss()
139 | for epoch in range(epoch_start, args.epochs):
140 | adjust_learning_rate(optimizer, epoch, args)
141 | 
142 | # warm-up
143 | if args.warm_up != 0 and epoch == args.warm_up:
144 | model.set_warm_up(False)
145 | # criterion = UnilateralMse(1.)
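# Warm-up schedule: for the first args.warm_up epochs the spiking nodes are bypassed
# (ANN-style pre-training); from epoch args.warm_up onward they are re-enabled and
# training continues as an SNN with the same optimizer state.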
146 | 147 | model.train() 148 | print("[EPOCH]: %d/%d" % (epoch, args.epochs)) 149 | loss_tot = AverageMeter() 150 | loss_sum = AverageMeter() 151 | acc_tot = AverageMeter() 152 | acc_sum = AverageMeter() 153 | 154 | # if epoch < 50: 155 | # model.set_ltd(False) 156 | # else: 157 | # model.set_ltd(True) 158 | 159 | for idx, data in enumerate(train_loader): 160 | # model train 161 | images = data[0].float().to(device) 162 | labels = data[1].to(device) 163 | # print(images.shape, labels.shape) 164 | 165 | optimizer.zero_grad() 166 | 167 | # with autocast(): 168 | outputs = model(images) 169 | loss = criterion(outputs, labels) if args.loss != 'mixed' else criterion(outputs, labels, epoch - epoch_start) 170 | # scaler.scale(loss).backward() 171 | loss.backward() 172 | # nn.utils.clip_grad_norm(model.parameters(), args.grad_clip) 173 | 174 | optimizer.step() 175 | 176 | acc, = accuracy(outputs, labels, topk=(1,)) 177 | loss_tot.update(loss.item(), outputs.shape[0]) 178 | loss_sum.update(loss.item(), outputs.shape[0]) 179 | acc_tot.update(acc, outputs.shape[0]) 180 | acc_sum.update(acc, outputs.shape[0]) 181 | 182 | # logging 183 | if idx % 1 == 0: 184 | # log.write('[train]:%d, loss:%.3f, accuracy:%.3f, m_lr:%.5f\n' 185 | # % (idx, loss_tot.avg, acc_tot.avg, 186 | # optimizer.param_groups[0]['lr'])) 187 | print('[train]:%d, loss:%.3f, accuracy:%.3f, m_lr:%.5f, ' 188 | % (idx, loss_tot.avg, acc_tot.avg, 189 | optimizer.param_groups[0]['lr']), end='\r') 190 | if idx % 10 == 0: 191 | loss_tot.reset() 192 | acc_tot.reset() 193 | 194 | print('[train total]:loss:%.3f, accuracy:%.3f, m_lr:%.5f' 195 | % (loss_sum.avg, acc_sum.avg, 196 | optimizer.param_groups[0]['lr']), end='\r') 197 | 198 | print('', end='\n') 199 | if epoch % args.infer_every == 0: 200 | infer(epoch) 201 | 202 | # scheduler.step() 203 | 204 | 205 | def infer(epoch): 206 | global best_acc 207 | model.eval() 208 | loss_tot = AverageMeter() 209 | acc_tot = AverageMeter() 210 | for idx, data in enumerate(test_loader): 211 | # model train 212 | with torch.no_grad(): 213 | images = data[0].float().to(device) 214 | labels = data[1].to(device) 215 | # print(images.shape, labels.shape) 216 | # with autocast(): 217 | outputs = model(images).detach() 218 | loss = criterion(outputs, labels) if args.loss != 'mixed' else criterion(outputs, labels, epoch - epoch_start) 219 | 220 | acc, = accuracy(outputs, labels, topk=(1,)) 221 | loss_tot.update(loss.item(), outputs.shape[0]) 222 | acc_tot.update(acc, outputs.shape[0]) 223 | # print(model.get_fire_rate()) 224 | # fire_rate = model.get_fire_rate() 225 | # logging 226 | # if idx % 1 == 0: 227 | s = '[I]:%d,ls:%.2f,acc:%.2f,' % (idx, loss_tot.avg, acc_tot.avg) 228 | s = s + 'fr:' + ''.join(['{:.2f},'.format(i) for i in model.get_fire_rate()]) 229 | # + 'th:' + ''.join(['{:.2f},'.format(i) for i in model.get_threshold()]) 230 | # + 'dy:' + ''.join(['{:.2f}, '.format(i) for i in model.get_decay()]) 231 | print(s, end='\r') 232 | 233 | print('', end='\n') 234 | if best_acc < acc_tot.avg: 235 | best_acc = acc_tot.avg 236 | ckpt = { 237 | 'network': model.state_dict(), 238 | 'epoch': epoch + 1, 239 | 'optimizer': optimizer.state_dict(), 240 | # 'scheduler': scheduler.state_dict(), 241 | 'best_acc': best_acc, 242 | } 243 | if args.save: 244 | save(ckpt) 245 | print('\033[0;32;40m[SAVE]\033[0m %.5f' % best_acc) 246 | else: 247 | print('\033[0;32;40m[BEST]\033[0m %.5f' % best_acc) 248 | 249 | 250 | def save(ckpt): 251 | if not os.path.exists(CKPT_DIR): 252 | os.mkdir(CKPT_DIR) 253 | torch.save(ckpt, 
os.path.join(CKPT_DIR, fname)) 254 | 255 | 256 | def load(fname): 257 | global model, scheduler, epoch_start, best_acc 258 | ckpt = torch.load(fname, map_location=device) 259 | print('[best accuracy]: %f' % ckpt['best_acc']) 260 | 261 | model.load_state_dict(ckpt['network'], strict=False) 262 | # optimizer.load_state_dict(ckpt['optimizer']) 263 | # scheduler = optim.lr_scheduler.CosineAnnealingLR( 264 | # optimizer, T_max=float(args.epochs), eta_min=args.lr_min, last_epoch=ckpt['epoch'] 265 | # ) 266 | print(ckpt['epoch']) 267 | epoch_start = ckpt['epoch'] 268 | best_acc = 0. # ckpt['best_acc'] 269 | 270 | 271 | def adjust_learning_rate(optimizer, epoch, args): 272 | lr = args.lr 273 | if hasattr(args, 'warmup') and epoch < args.warmup: 274 | lr = lr / (args.warmup - epoch) 275 | # elif not args.disable_cos: 276 | else: 277 | lr *= 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup) / (args.epochs - args.warmup))) 278 | 279 | for param_group in optimizer.param_groups: 280 | param_group['lr'] = lr 281 | 282 | # lr = args.lr * (0.1 ** (epoch // 80)) 283 | # for param_group in optimizer.param_groups: 284 | # param_group['lr'] = lr 285 | 286 | 287 | if __name__ == '__main__': 288 | # model.set_warm_up(True) 289 | # if args.dataset == 'imnet': 290 | # inputs = torch.rand(1, 3, 224, 224, device=device) 291 | # elif args.dataset == 'cifar10': 292 | # inputs = torch.rand(1, 3, 32, 32, device=device) 293 | # elif args.dataset == 'fashion' or 'mnist': 294 | # inputs = torch.rand(1, 1, 28, 28, device=device) 295 | # else: 296 | # raise NotImplementedError 297 | # model.eval() 298 | # flops, params = profile(model, inputs=(inputs,)) 299 | # flops, params = clever_format([flops, params], '%.3f') 300 | # model.set_warm_up(False) 301 | 302 | # model.train() 303 | # print('[FLOPS] {}, PARAMS {}'.format(flops, params)) 304 | 305 | if args.resume != '': 306 | load(args.resume) 307 | infer(0) 308 | train() 309 | print('[BEST accuracy] {}'.format(best_acc)) 310 | -------------------------------------------------------------------------------- /train_cifar10.sh: -------------------------------------------------------------------------------- 1 | cd code 2 | python train.py ConvNet --dataset cifar10 --device 0 --node HTGLIFNode --batch-size 128 --step 8 --warm-up 80 --epoch 500 -------------------------------------------------------------------------------- /train_dvsc10.sh: -------------------------------------------------------------------------------- 1 | cd code 2 | python train.py ConvNet --dataset dvsc10 --device 0 --node HTGLIFNode --batch-size 64 --step 16 -------------------------------------------------------------------------------- /train_dvsg.sh: -------------------------------------------------------------------------------- 1 | cd code 2 | python train.py ConvNet --dataset dvsg --device 2 --node HTGLIFNode --batch-size 16 -------------------------------------------------------------------------------- /train_nmnist.sh: -------------------------------------------------------------------------------- 1 | cd code 2 | python train.py ConvNet --dataset nmnist --device 0 --node HTGLIFNode --batch-size 64 --step 16 --------------------------------------------------------------------------------