├── LICENSE ├── README.md ├── apis ├── __init__.py ├── test.py └── train.py ├── collect_distribution.py ├── config.py ├── dataset.py ├── datasets ├── __init__.py ├── action │ └── dataloaders.py ├── event_drop.py ├── events_timeslices.py ├── gait │ └── dataloaders.py ├── gesture │ └── dataloaders.py ├── recogition │ └── dataloaders.py └── transforms.py ├── final.sh ├── firing_num.py ├── img └── abstract.png ├── model.py ├── models ├── Net.py ├── Net2.py ├── Net3.py ├── Net4.py ├── __init__.py └── module │ ├── Attn.py │ ├── LIF.py │ └── __init__.py ├── requirements.txt ├── train_test.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inherent Redundancy in Spiking Neural Networks ICCV2023 2 | 3 | Spiking Neural Networks (SNNs) are well known as a promising energy-efficient alternative to conventional artificial neural networks. 
Owing to the preconceived impression that SNNs fire sparsely, the analysis and optimization of the inherent redundancy in SNNs have been largely overlooked, so the potential advantages of spike-based neuromorphic computing in accuracy and energy efficiency are compromised. In this work, we pose and focus on three key questions regarding the inherent redundancy in SNNs. We argue that the redundancy is induced by the spatio-temporal invariance of SNNs, which enhances the efficiency of parameter utilization but also introduces many noise spikes. Further, we analyze the effect of spatio-temporal invariance on the spatio-temporal dynamics and spike firing of SNNs. Then, motivated by these analyses, we propose an Advance Spatial Attention (ASA) module to harness SNNs' redundancy, which adaptively optimizes their membrane potential distribution with a pair of individual spatial attention sub-modules; in this way, noise spike features are accurately regulated. Experimental results demonstrate that the proposed method significantly reduces spike firing while achieving better performance than state-of-the-art baselines. 4 | 5 | ![abstract](./img/abstract.png) 6 | 7 | Our module's details are in `./models/module/Attn.py`. 8 | 
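For intuition, the snippet below is a minimal, self-contained sketch of a spatial-attention gate acting on the membrane potential, in the spirit of the description above. It is **not** the repository's ASA implementation: the class names `SpatialAttentionSketch` and `ASASketch`, and all design choices (mean/max pooling over channels, a shared 3×3 convolution, simple averaging of the two sub-modules) are illustrative assumptions. See `./models/module/Attn.py` for the actual module and how it is wired into the LIF layers; in an SNN forward pass such a gate would typically modulate the membrane potential at each time step before the spike threshold is applied.

```python
import torch
import torch.nn as nn


class SpatialAttentionSketch(nn.Module):
    # Illustrative sketch only -- not the ASA module from ./models/module/Attn.py.
    # Builds one 2D attention map from channel-wise mean/max statistics of the
    # membrane potential and uses it to re-weight every spatial location.
    def __init__(self, kernel_size: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # mem: membrane potential, shape [N, C, H, W]
        avg_map = mem.mean(dim=1, keepdim=True)   # [N, 1, H, W]
        max_map = mem.amax(dim=1, keepdim=True)   # [N, 1, H, W]
        attn = torch.sigmoid(self.conv(torch.cat([avg_map, max_map], dim=1)))
        return mem * attn                         # spatially re-weighted membrane potential


class ASASketch(nn.Module):
    # A pair of individual spatial-attention sub-modules whose outputs are
    # combined before thresholding, so noisy locations can be suppressed
    # (the combination rule here is an assumption made for illustration).
    def __init__(self):
        super().__init__()
        self.attn_a = SpatialAttentionSketch()
        self.attn_b = SpatialAttentionSketch()

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        return 0.5 * (self.attn_a(mem) + self.attn_b(mem))


if __name__ == "__main__":
    mem = torch.randn(4, 64, 32, 32)              # [batch, channels, H, W]
    print(ASASketch()(mem).shape)                 # -> torch.Size([4, 64, 32, 32])
```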
9 | ## train & test 10 | 11 | ### requirements 12 | 13 | - python >= 3.7 14 | - torch >= 1.10 15 | - torchvision >= 0.11 16 | - spikingjelly == 0.0.0.0.12 17 | - h5py 18 | - pandas 19 | - einops 20 | - seaborn: only for plotting features 21 | - torchinfo: only for inspecting the details of networks 22 | 23 | ### quick-start 24 | 25 | ```shell 26 | sh final.sh 27 | ``` 28 | 29 | ### train on your own dataset 30 | 31 | ```shell 32 | CUDA_VISIBLE_DEVICES=0 python3 train_test.py \ 33 | -arch A \ 34 | -T 60 \ 35 | -attention ASA \ 36 | -dataset yourdata \ 37 | -data_path DATA_PATH \ 38 | -batch_size_train 32 \ 39 | -batch_size_test 4 \ 40 | -clip 10 \ 41 | -num_epochs 200 \ 42 | -seed 0 \ 43 | -ps 00011 44 | ``` 45 | 46 | ### args' helper 47 | 48 | ```shell 49 | python3 train_test.py -h 50 | ``` 51 | 52 | 1. `arch`: A for the three-layer SNN; B for the five-layer SNN 53 | 2. `T`: the number of time steps 54 | 3. `attention`: the attention module to use, "no" or "ASA" in this project 55 | 4. `dataset`: the dataset to use 56 | 5. `data_path`: the path to your dataset 57 | 6. `batch_size_train / test`: the batch sizes used for training and inference 58 | 7. `clip`: the number of clips used by the data augmentation algorithm "RCS" 59 | 8. `num_epochs`: the number of training epochs 60 | 9. `seed`: the random seed 61 | 10. `ps`: whether dropout is on / off for each layer 62 | 63 | > details of all configs can be found in `./config.py` 64 | 65 | ## performance 66 | 67 | | Dataset | Model | Acc.(%) | NASFR | 68 | | -- | -- | -- | -- | 69 | | DVS128 Gesture | LIF-SNN | 91.3 | 0.176 | 70 | | | . + ASA | 95.2(**+3.9**) | 0.038(**-78.4%**) | 71 | | DVS128 Gesture by spikingjelly | LIF-SNN-Five | 95.5 | 0.023 | 72 | | | . + ASA | 97.7(**+2.2**) | 0.018(**-21.7%**) | 73 | | DVS128 Gait-day | LIF-SNN | 88.6 | 0.214 | 74 | | | . + ASA | 93.6(**+5.0**) | 0.045(**-78.9%**) | 75 | | DVS128 Gait-night | LIF-SNN | 96.4 | 0.197 | 76 | | | . + ASA | 98.6(**+2.2**) | 0.126(**-36.0%**) | 77 | | DailyAction-DVS | LIF-SNN | 92.5 | 0.017 | 78 | | | . + ASA | 94.6(**+2.1**) | 0.013(**-23.5%**) | 79 | | HAR-DVS | SEW-Res-SNN-18 | 45.5 | 0.206 | 80 | | | . + ASA | 47.1(**+1.6**) | 0.183(**-11.2%**) | 81 | 82 | ## Citation 83 | 84 | ``` 85 | @inproceedings{yao2023inherent, 86 | title={Inherent Redundancy in Spiking Neural Networks}, 87 | author={Yao, Man and Hu, Jiakui and Zhao, Guangshe and Wang, Yaoyuan and Zhang, Ziyang and Xu, Bo and Li, Guoqi}, 88 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 89 | pages={16924--16934}, 90 | year={2023} 91 | } 92 | ``` 93 | 94 | ## Contact Information 95 | 96 | For help or issues using this repository, please submit a GitHub issue. 97 | 98 | For other communications related to this repository, please contact `manyao@ia.ac.cn` and `jkhu29@stu.pku.edu.cn`. 99 | -------------------------------------------------------------------------------- /apis/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from apis.train import train 5 | from apis.test import test 6 | 7 | 8 | def run(args): 9 | args.best_acc = 0 10 | args.best_epoch = 0 11 | 12 | args.epoch_list = [] 13 | args.loss_train_list = [] 14 | args.loss_test_list = [] 15 | args.acc_train_list = [] 16 | args.acc_test_list = [] 17 | 18 | if args.pretrained_path is not None: 19 | pre_dict = torch.load(args.pretrained_path)["net"] 20 | pre = {} 21 | for k, _ in pre_dict.items(): 22 | pre[k[7:]] = pre_dict[k] 23 | model_dict = args.model.state_dict() 24 | pre_dict = {k: v for k, v in pre.items() if k in model_dict} 25 | # print(pre_dict.keys()) 26 | model_dict.update(pre_dict) 27 | args.model.load_state_dict(model_dict) 28 | print("loading model...") 29 | 30 | for args.epoch in range(args.num_epochs): 31 | args.model.train() 32 | train(args=args) 33 | 34 | args.train_loss = args.train_loss / len(args.train_loader) 35 | args.epoch_list.append(args.epoch + 1) 36 | args.train_acc = ( 37 | 100.0 * float(args.train_correct) / float(len(args.train_dataset)) 38 | ) 39 | print("epoch:", args.epoch + 1) 40 | print("dt:", args.dt) 41 | print("T:", args.T) 42 | print("Train loss:%.5f" % args.train_loss) 43 | print("Train acc: %.3f" % args.train_acc) 44 | 45 | if args.lr_scheduler: 46 | args.scheduler.step(args.epoch) 47 | 48 | args.loss_train_list.append(args.train_loss) 49 | args.acc_train_list.append(args.train_acc) 50 | 51 | # test 52 | with torch.no_grad(): 53 | args.model.eval() 54 | test(args=args) 55 | 56 | args.test_loss = args.test_loss / len(args.test_loader) 57 | args.test_acc = ( 58 | 100.0 * float(args.test_correct) / float(len(args.test_dataset)) 59 | ) 60 | args.loss_test_list.append(args.test_loss) 61 | print("Test loss:%.5f" % args.test_loss) 62 | print("Test acc: %.3f" % args.test_acc) 63 | 64 | args.acc_test_list.append(args.test_acc) 65 | 66 | if args.test_acc >= args.best_acc: 67 | args.best_epoch = args.epoch + 1 68 | args.best_acc = args.test_acc 69 | 70 | print("Saving..") 71 | 72 | if not os.path.exists(args.modelPath): 73 | os.makedirs(args.modelPath) 74 | torch.save( 75 | args.model.state_dict(), 76 | args.modelPath + os.sep + args.modelNames, 77 | ) 78 | 79 | print("best acc:", args.best_acc) 80 | -------------------------------------------------------------------------------- /apis/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | from spikingjelly.clock_driven import functional 5 | 6 | 7 | def test(args): 8 | args.test_loss = 0 9 | args.test_correct = 0 10 | bar_test = tqdm(total=len(args.test_loader)) 11 | for _, 
(input, labels) in enumerate(args.test_loader): 12 | functional.reset_net(args.model) 13 | 14 | b = input.size()[0] 15 | 16 | if "sj" not in args.dataset: 17 | input = input.reshape( 18 | b * args.clip, 19 | input.size()[2], 20 | input.size()[3], 21 | input.size()[4], 22 | input.size()[5], 23 | ) 24 | input = input.float().to(args.device) 25 | if len(labels.shape) == 3: 26 | labels = labels.reshape(b * args.clip, labels.size()[2]) 27 | labels = labels.float().to(args.device) 28 | else: 29 | labels = labels.reshape( 30 | b * args.clip, labels.size()[2], labels.size()[3] 31 | ) 32 | labels = labels[:, 1, :].float().to(args.device) 33 | else: 34 | input = input.float().to(args.device) 35 | if len(labels.shape) == 2: 36 | labels = labels.float().to(args.device) 37 | elif len(labels.shape) == 1: 38 | labels = F.one_hot(labels, args.num_classes).float().to(args.device) 39 | 40 | outputs, _ = args.model(input) 41 | 42 | loss = args.criterion(outputs, labels) 43 | 44 | _, predicted = torch.max(outputs.data, 1) 45 | _, labelTest = torch.max(labels.data, 1) 46 | 47 | for i in range(b): 48 | predicted_clips = predicted[i * args.clip : (i + 1) * args.clip] 49 | labelTest_clips = labelTest[i * args.clip : (i + 1) * args.clip] 50 | test_clip_correct = (predicted_clips == labelTest_clips).sum().item() 51 | if test_clip_correct / args.clip > 0.5: 52 | args.test_correct += 1 53 | 54 | args.test_loss += loss.item() / args.clip 55 | 56 | functional.reset_net(args.model) 57 | 58 | bar_test.update() 59 | bar_test.set_description( 60 | "Test:Epoch[%d/%d]" % (args.epoch + 1, args.num_epochs) 61 | ) 62 | bar_test.set_postfix(Loss=loss.item()) 63 | 64 | bar_test.close() 65 | -------------------------------------------------------------------------------- /apis/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | from spikingjelly.clock_driven import functional 5 | 6 | 7 | def train(args): 8 | args.train_loss = 0 9 | args.train_correct = 0 10 | bar_train = tqdm(total=len(args.train_loader)) 11 | 12 | for _, (input, labels) in enumerate(args.train_loader): 13 | functional.reset_net(args.model) 14 | 15 | args.optimizer.zero_grad() 16 | 17 | input = input.float().to(args.device) 18 | if len(labels.shape) == 2: 19 | labels = labels.float().to(args.device) 20 | elif len(labels.shape) == 1: 21 | labels = F.one_hot(labels, args.num_classes).float().to(args.device) 22 | else: 23 | labels = labels[:, 1, :].float().to(args.device) 24 | 25 | outputs, lasso = args.model(input) 26 | 27 | if args.fbs or args.gate: 28 | loss = args.criterion(outputs, labels) + args.lasso_lambda * lasso 29 | else: 30 | loss = args.criterion(outputs, labels) 31 | 32 | _, predicted = torch.max(outputs.data, 1) 33 | _, labelTest = torch.max(labels.data, 1) 34 | args.train_correct += (predicted == labelTest).sum().item() 35 | 36 | args.train_loss += loss.item() 37 | 38 | loss.backward() 39 | 40 | args.optimizer.step() 41 | 42 | bar_train.update() 43 | bar_train.set_description( 44 | "Train:Epoch[%d/%d]" % (args.epoch + 1, args.num_epochs) 45 | ) 46 | bar_train.set_postfix(Loss=loss.item()) 47 | 48 | functional.reset_net(args.model) 49 | 50 | bar_train.close() 51 | -------------------------------------------------------------------------------- /collect_distribution.py: -------------------------------------------------------------------------------- 1 | import os, time 2 | import h5py 3 | import torch 4 | import numpy as np 5 | 
import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | from spikingjelly.clock_driven import functional 8 | 9 | from model import create_net 10 | from dataset import create_dataset 11 | from config import parser 12 | from utils import set_seed 13 | 14 | 15 | @torch.no_grad() 16 | def firing_num(args): 17 | args.model.load_state_dict(torch.load(args.modelPath)) 18 | args.model.eval() 19 | 20 | save_name_file = os.path.join(args.name, "S_Feature") 21 | args.path = save_name_file 22 | if not os.path.exists(save_name_file): 23 | os.makedirs(save_name_file) 24 | f = h5py.File(os.path.join(save_name_file, "all.h5"), "w") 25 | 26 | all_idx = 0 27 | flag = True 28 | 29 | for n_iter, (input, labels) in enumerate(args.test_loader): 30 | functional.reset_net(args.model) 31 | 32 | b = input.size()[0] 33 | 34 | input = input.reshape( 35 | b * args.clip, 36 | input.size()[2], 37 | input.size()[3], 38 | input.size()[4], 39 | input.size()[5], 40 | ) 41 | input = input.float().to(args.device) 42 | if len(labels.shape) == 3: 43 | labels = labels.reshape(b * args.clip, labels.size()[2]) 44 | labels = labels.float().to(args.device) 45 | else: 46 | labels = labels.reshape(b * args.clip, labels.size()[2], labels.size()[3]) 47 | labels = labels[:, 1, :].float().to(args.device) 48 | 49 | output, firing_nums, mps = args.model(input=input, mp_collect=True) 50 | 51 | for i in range(len(firing_nums)): 52 | if i < 3: 53 | firing_num_mat = firing_nums[i].reshape( 54 | b * args.clip, 55 | -1, 56 | firing_nums[i].size()[2], 57 | firing_nums[i].size()[3], 58 | firing_nums[i].size()[4], 59 | ) 60 | # print(firing_num_mat.shape) 61 | f.create_dataset( 62 | str(all_idx) + "_S_" + str(i), 63 | data=firing_num_mat.squeeze().cpu().numpy(), 64 | ) 65 | 66 | if flag: 67 | if args.arch == "A": 68 | layer = 3 69 | flag = False 70 | 71 | _, predicted = torch.max(output.data, 1) 72 | _, labelTest = torch.max(labels.data, 1) 73 | for i in range(b): 74 | predicted_clips = predicted[i * args.clip : (i + 1) * args.clip] 75 | labelTest_clips = labelTest[i * args.clip : (i + 1) * args.clip] 76 | test_clip_correct = (predicted_clips == labelTest_clips).sum().item() 77 | if test_clip_correct / args.clip > 0.5: 78 | args.test_correct += 1 79 | args.test_acc = 100.0 * float(args.test_correct) / (float(n_iter + 1) * b) 80 | print( 81 | "iteration: {}/{}\tacc: {}".format( 82 | n_iter + 1, len(args.test_loader), args.test_acc 83 | ) 84 | ) 85 | 86 | all_idx += 1 87 | 88 | g = h5py.File(os.path.join(args.path, "all.h5"), "r") 89 | f = h5py.File(os.path.join(args.path, "mean.h5"), "w") 90 | 91 | names = [] 92 | data_all = [] 93 | for pre in ["S"]: 94 | for nex in range(layer): 95 | data_all.append([]) 96 | names.append(pre + "_" + str(nex)) 97 | 98 | for name in names: 99 | start_time = time.time() 100 | X = [] 101 | for idx in range(len(args.test_loader)): 102 | if idx == 0: 103 | X = g[str(idx) + "_" + name][()] 104 | else: 105 | X = X + g[str(idx) + "_" + name][()] 106 | # X.append(g[str(idx) + '_' + name][()]) 107 | # X = np.array(X) 108 | f.create_dataset(name, data=X / len(args.test_dataset)) 109 | # f1.create_dataset(name, data=np.std(X, axis=0)) 110 | print("costs:", time.time() - start_time) 111 | 112 | f.close() 113 | 114 | 115 | @torch.no_grad() 116 | def mp_num(args): 117 | args.model.load_state_dict(torch.load(args.modelPath)) 118 | args.model.eval() 119 | 120 | save_name_file = os.path.join(args.name, "S_mp") 121 | args.path = save_name_file 122 | if not os.path.exists(save_name_file): 123 | os.makedirs(save_name_file) 124 | f 
= h5py.File(os.path.join(save_name_file, "all.h5"), "w") 125 | 126 | all_idx = 0 127 | flag = True 128 | 129 | for n_iter, (input, labels) in enumerate(args.test_loader): 130 | functional.reset_net(args.model) 131 | 132 | b = input.size()[0] 133 | 134 | input = input.reshape( 135 | b * args.clip, 136 | input.size()[2], 137 | input.size()[3], 138 | input.size()[4], 139 | input.size()[5], 140 | ) 141 | input = input.float().to(args.device) 142 | if len(labels.shape) == 3: 143 | labels = labels.reshape(b * args.clip, labels.size()[2]) 144 | labels = labels.float().to(args.device) 145 | else: 146 | labels = labels.reshape(b * args.clip, labels.size()[2], labels.size()[3]) 147 | labels = labels[:, 1, :].float().to(args.device) 148 | 149 | output, _, mps = args.model(input=input, mp_collect=True) 150 | 151 | if flag: 152 | if args.arch == "A": 153 | layer = 3 154 | elif args.arch == "B": 155 | layer = 5 156 | flag = False 157 | 158 | for i in range(len(mps)): 159 | if i < layer: 160 | mp = mps[i].reshape( 161 | b * args.clip, 162 | -1, 163 | mps[i].size()[2], 164 | mps[i].size()[3], 165 | mps[i].size()[4], 166 | ) 167 | # print(np.sum(mp.cpu().numpy()>0.3) / mp.numel()) 168 | f.create_dataset( 169 | str(all_idx) + "_S_" + str(i), data=mp.squeeze().cpu().numpy() 170 | ) 171 | 172 | _, predicted = torch.max(output.data, 1) 173 | _, labelTest = torch.max(labels.data, 1) 174 | for i in range(b): 175 | predicted_clips = predicted[i * args.clip : (i + 1) * args.clip] 176 | labelTest_clips = labelTest[i * args.clip : (i + 1) * args.clip] 177 | test_clip_correct = (predicted_clips == labelTest_clips).sum().item() 178 | if test_clip_correct / args.clip > 0.5: 179 | args.test_correct += 1 180 | args.test_acc = 100.0 * float(args.test_correct) / float(n_iter + 1) 181 | print( 182 | "iteration: {}/{}\tacc: {}".format( 183 | n_iter + 1, len(args.test_dataset), args.test_acc 184 | ) 185 | ) 186 | 187 | all_idx += 1 188 | 189 | g = h5py.File(os.path.join(args.path, "all.h5"), "r") 190 | f = h5py.File(os.path.join(args.path, "mean.h5"), "w") 191 | 192 | names = [] 193 | data_all = [] 194 | for pre in ["S"]: 195 | for nex in range(layer): 196 | data_all.append([]) 197 | names.append(pre + "_" + str(nex)) 198 | 199 | for name in names: 200 | start_time = time.time() 201 | X = [] 202 | for idx in range(len(args.test_dataset)): 203 | if idx == 0: 204 | X = g[str(idx) + "_" + name][()] 205 | else: 206 | X = X + g[str(idx) + "_" + name][()] 207 | # X.append(g[str(idx) + '_' + name][()]) 208 | # X = np.array(X) 209 | # print(np.sum(X / len(args.test_dataset)>=0.3) / np.size(X)) 210 | f.create_dataset(name, data=X / len(args.test_dataset)) 211 | # f1.create_dataset(name, data=np.std(X, axis=0)) 212 | print("costs:", time.time() - start_time) 213 | 214 | f.close() 215 | 216 | 217 | def mean_std_firing(args): 218 | # args.path = os.path.join(args.name, "S_Feature") 219 | # print(os.path.join(args.path, args.mean_std + '.h5')) 220 | f = h5py.File(os.path.join(args.path, args.mean_std + ".h5"), "r") 221 | save_path_mean = args.path + os.sep + args.mean_std 222 | 223 | save_path_mean_name = os.path.join(save_path_mean, args.name) 224 | if not os.path.exists(save_path_mean_name): 225 | os.makedirs(save_path_mean_name) 226 | data_ = f[args.feature_name][()] 227 | f.close() 228 | 229 | _, T, C, _, _ = data_.shape 230 | 231 | for t in range(T): 232 | for c in range(C): 233 | data = np.flipud(data_[0, t, c, :, :][()].T) 234 | # print(data.shape) 235 | if data.shape[0] == 32: 236 | fig = plt.figure(figsize=(50, 50), clear=True) 
237 | elif data.shape[0] == 16: 238 | fig = plt.figure(figsize=(25, 25), clear=True) 239 | elif data.shape[0] == 8: 240 | fig = plt.figure(figsize=(12, 12), clear=True) 241 | else: 242 | fig = plt.figure(figsize=(5, 5), clear=True) 243 | 244 | ax = sns.heatmap( 245 | data=data, 246 | annot=True, 247 | fmt=".2f", 248 | cbar=False, 249 | cmap="RdBu_r", 250 | vmin=data.min(), 251 | vmax=data.max(), 252 | ) 253 | plt.xlabel("max:" + str(data.max()) + "_min:" + str(data.min())) 254 | plt.tight_layout() 255 | plt.title("C_" + str(c) + "T_" + str(t)) 256 | # plt.show() 257 | plt.savefig( 258 | os.path.join( 259 | save_path_mean_name, "C_" + str(c) + "T_" + str(t) + ".png" 260 | ), 261 | bbox_inches="tight", 262 | pad_inches=0, 263 | ) 264 | 265 | plt.clf() 266 | plt.cla() 267 | plt.close(fig) 268 | 269 | 270 | def mean_std_mp(args): 271 | # args.path = os.path.join(args.name, "S_Feature") 272 | # print(os.path.join(args.path, args.mean_std + '.h5')) 273 | f = h5py.File(os.path.join(args.path, args.mean_std + ".h5"), "r") 274 | save_path_mean = args.path + os.sep + args.mean_std 275 | 276 | save_path_mean_name = os.path.join(save_path_mean, args.feature_name) 277 | if not os.path.exists(save_path_mean_name): 278 | os.makedirs(save_path_mean_name) 279 | data_ = f[args.feature_name][()] 280 | f.close() 281 | 282 | _, T, C, _, _ = data_.shape 283 | 284 | spike = [] 285 | for t in range(T): 286 | for c in range(C): 287 | data = np.flipud(data_[0, t, c, :, :][()].T) 288 | spike.append(np.sum(data >= 0.3) / np.size(data)) 289 | 290 | fig = plt.figure(figsize=(5, 5), clear=True) 291 | 292 | sns.distplot(data, bins=50, kde=False) 293 | plt.xlabel("max:" + str(data.max()) + "_min:" + str(data.min())) 294 | plt.tight_layout() 295 | plt.title("C_" + str(c) + "T_" + str(t)) 296 | # plt.show() 297 | plt.savefig( 298 | os.path.join( 299 | save_path_mean_name, "C_" + str(c) + "T_" + str(t) + ".png" 300 | ), 301 | bbox_inches="tight", 302 | pad_inches=0, 303 | ) 304 | print("C_" + str(c) + "T_" + str(t)) 305 | 306 | plt.clf() 307 | plt.cla() 308 | plt.close(fig) 309 | 310 | 311 | def main(): 312 | args = parser.parse_args() 313 | 314 | set_seed(args.seed) 315 | 316 | args.test_correct = 0.0 317 | args.recordPath = args.modelPath 318 | args.im_width, args.im_height = (128 // args.ds, 128 // args.ds) 319 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 320 | args.device_ids = range(torch.cuda.device_count()) 321 | 322 | # args.name = args.dataset + '_dt=' + str(args.dt) + 'ms' + '_T=' + str(args.T) + '_attn=' + args.attention 323 | args.name = ( 324 | args.dataset 325 | + "_dt=" 326 | + str(args.dt) 327 | + "ms" 328 | + "_T=" 329 | + str(args.T) 330 | + "_attn=" 331 | + args.attention 332 | + "_seed=" 333 | + str(args.seed) 334 | + "_arch=" 335 | + args.arch 336 | ) 337 | args.modelNames = args.name + ".pth" 338 | args.recordNames = args.name + ".csv" 339 | 340 | print(args) 341 | 342 | create_dataset(args=args) 343 | 344 | create_net(args=args) 345 | 346 | if args.collect == "firing": 347 | firing_num(args=args) 348 | # args.mean_std = "mean" 349 | # mean_std_firing(args=args) 350 | elif args.collect == "mp": 351 | mp_num(args=args) 352 | # args.mean_std = "mean" 353 | # mean_std_mp(args=args) 354 | 355 | 356 | if __name__ == "__main__": 357 | main() 358 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 
import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument("-arch", type=str, default="A") 9 | parser.add_argument("-ds", type=int, default=4) 10 | parser.add_argument("-dt", type=int, default=15) 11 | parser.add_argument("-T", type=int, default=60) 12 | parser.add_argument("-seed", type=int, default=118) 13 | parser.add_argument("-epoch", type=int, default=0) 14 | parser.add_argument("-num", type=int, default=3) 15 | parser.add_argument("-num_epochs", type=int, default=100) 16 | parser.add_argument("-batch_size_train", type=int, default=32) 17 | parser.add_argument("-batch_size_test", type=int, default=4) 18 | 19 | # save 20 | parser.add_argument("-name", type=str, default=None) 21 | parser.add_argument( 22 | "-modelPath", 23 | type=str, 24 | default=os.path.dirname(os.path.abspath(__file__)) + os.sep + "result/", 25 | ) 26 | parser.add_argument("-modelNames", type=str, default=None) 27 | parser.add_argument("-recordPath", type=str, default=None) 28 | parser.add_argument("-recordNames", type=str, default=None) 29 | 30 | # Data 31 | parser.add_argument("-fbs", type=str, default=False) 32 | parser.add_argument("-lasso_lambda", type=float, default=1e-8) 33 | parser.add_argument("-sparsity_ratio", type=float, default=0.5) 34 | parser.add_argument("-dataset", type=str, default="gesture") 35 | parser.add_argument("-data_path", type=str, default="/data1/DVSGesture") 36 | 37 | # FBS 38 | parser.add_argument("-clip", type=int, default=10) 39 | parser.add_argument("-is_train_Enhanced", type=bool, default=True) 40 | parser.add_argument("-is_spike", type=bool, default=False) 41 | parser.add_argument("-interval_scaling", type=bool, default=False) 42 | 43 | # Gate 44 | parser.add_argument("-gate", type=bool, default=False) 45 | parser.add_argument("-target_rate", type=float, default=0.7) 46 | 47 | # Param 48 | parser.add_argument("-init_method", type=str, default=None) 49 | parser.add_argument("-pretrained_path", default=None) 50 | parser.add_argument("-criterion", default=nn.MSELoss()) 51 | 52 | # Network 53 | parser.add_argument("-in_channels", default=2) 54 | parser.add_argument("-num_classes", type=int, default=11) 55 | parser.add_argument("-beta", type=float, default=0) 56 | parser.add_argument("-alpha", type=float, default=0.3) 57 | parser.add_argument("-Vreset", type=float, default=0) 58 | parser.add_argument("-Vthres", type=float, default=0.3) 59 | parser.add_argument("-mem_act", default=torch.relu) 60 | parser.add_argument("-mode_select", type=str, default="spike") 61 | parser.add_argument("-TR_model", type=str, default="NTR") 62 | parser.add_argument("-track_running_stats", type=bool, default=True) 63 | parser.add_argument("-lens", type=float, default=0.25) 64 | parser.add_argument("-ps", type=str) 65 | 66 | # optimizer 67 | parser.add_argument("-lr", type=float, default=1e-4) 68 | parser.add_argument("-betas", default=[0.9, 0.999]) 69 | parser.add_argument("-eps", type=float, default=1e-8) 70 | parser.add_argument("-weight_decay", type=float, default=1e-4) 71 | parser.add_argument("-lr_scheduler", type=bool, default=True) 72 | 73 | # Attention 74 | parser.add_argument("-attention", type=str, default="no") 75 | parser.add_argument("-attention_before_conv", type=bool, default=False) 76 | parser.add_argument("-attention_per_time", type=bool, default=False) 77 | parser.add_argument("-attention_in_spike", type=bool, default=False) 78 | parser.add_argument("-reduction", type=int, default=1) 79 | parser.add_argument("-lam", type=float, default=0.5) 80 | 81 | # Dataloder 
82 | parser.add_argument("-drop_last", type=bool, default=False) 83 | parser.add_argument("-pip_memory", type=bool, default=True) 84 | parser.add_argument("-num_work", type=int, default=4) 85 | 86 | parser.add_argument("-collect", type=str, default="firing") 87 | parser.add_argument("-feature_name", type=str, default="S_0") 88 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_dataset(args): 5 | if args.dataset == "gesture": 6 | from datasets.gesture.dataloaders import create_datasets 7 | elif args.dataset == "gesture_sj": 8 | from spikingjelly.datasets.dvs128_gesture import DVS128Gesture 9 | 10 | def create_datasets(root, train, **kwargs): 11 | return DVS128Gesture( 12 | root, 13 | train=train, 14 | data_type="frame", 15 | frames_number=args.T, 16 | split_by="number", 17 | ) 18 | 19 | elif args.dataset == "gait_night" or args.dataset == "gait_day": 20 | from datasets.gait.dataloaders import create_datasets 21 | elif args.dataset == "recogition": 22 | from datasets.recogition.dataloaders import create_datasets 23 | elif args.dataset == "action": 24 | from datasets.action.dataloaders import create_datasets 25 | 26 | args.train_dataset = create_datasets( 27 | args.data_path, 28 | train=True, 29 | is_train_Enhanced=args.is_train_Enhanced, 30 | ds=args.ds, 31 | dt=args.dt * 1000, 32 | is_spike=args.is_spike, 33 | interval_scaling=args.interval_scaling, 34 | T=args.T, 35 | ) 36 | args.test_dataset = create_datasets( 37 | args.data_path, 38 | train=False, 39 | ds=args.ds, 40 | dt=args.dt * 1000, 41 | clip=args.clip, 42 | is_spike=args.is_spike, 43 | interval_scaling=args.interval_scaling, 44 | T=args.T, 45 | ) 46 | 47 | # Data loader 48 | args.train_loader = torch.utils.data.DataLoader( 49 | args.train_dataset, 50 | batch_size=args.batch_size_train, 51 | shuffle=True, 52 | drop_last=args.drop_last, 53 | num_workers=args.num_work, 54 | pin_memory=args.pip_memory, 55 | ) 56 | args.test_loader = torch.utils.data.DataLoader( 57 | args.test_dataset, 58 | batch_size=args.batch_size_test, 59 | shuffle=False, 60 | drop_last=args.drop_last, 61 | num_workers=args.num_work, 62 | pin_memory=args.pip_memory, 63 | ) 64 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/action/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import struct 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchvision.transforms import Compose,RandomCrop,CenterCrop 7 | 8 | from datasets.transforms import find_first, Repeat, toOneHot, ToTensor 9 | from datasets.events_timeslices import chunk_evs_pol_dvs_gesture, get_tmad_slice 10 | 11 | 12 | mapping = { 13 | 'climb':0, 14 | 'stand':1, 15 | 'jump':2, 16 | 'fall':3, 17 | 'sit':4, 18 | 'get up':5, 19 | 'walk':6, 20 | 'run':7, 21 | 'lift':8, 22 | 'lie':9, 23 | 'bend':10, 24 | 'pick':11 25 | } 26 | 27 | 28 | class DailyActionDataset(Dataset): 29 | def __init__( 30 | self, 31 | root, 32 | train=True, 33 | transform=None, 34 | target_transform=None, 35 | clip=10, 36 | is_train_Enhanced=False, 37 | dt=1000, 38 | size=[2, 32, 32], 39 | ds=4, 40 | is_spike=False, 41 | interval_scaling=False, 
42 | T=16, 43 | ): 44 | super(DailyActionDataset, self).__init__() 45 | 46 | self.n = 0 47 | self.root = root 48 | self.train = train 49 | self.chunk_size = T 50 | self.clip = clip 51 | self.is_train_Enhanced = is_train_Enhanced 52 | self.dt = dt 53 | self.transform = transform 54 | self.target_transform = target_transform 55 | self.size = size 56 | self.ds = ds 57 | 58 | self.is_spike = is_spike 59 | self.interval_scaling = interval_scaling 60 | if train: 61 | self.class_dir_train = os.listdir(self.root) 62 | else: 63 | self.class_dir_test = os.listdir(self.root) 64 | 65 | if train: 66 | root_train = self.root 67 | self.files_train = [] 68 | for self.now_file, self.label_train_one, self.files_train_one in os.walk(root_train): 69 | i = 0 70 | for now_data_file in self.files_train_one: 71 | if i < 120 * 0.8: 72 | self.files_train.append(os.path.join(self.now_file,now_data_file)) 73 | i += 1 74 | self.n = len(self.files_train) 75 | else: 76 | root_test = self.root 77 | self.files_test = [] 78 | for self.now_file, self.label_train_one, self.files_train_one in os.walk(root_test): 79 | i = 0 80 | for now_data_file in self.files_train_one: 81 | if i >= 120 * 0.8: 82 | self.files_test.append(os.path.join(self.now_file,now_data_file)) 83 | i += 1 84 | self.n = len(self.files_test) 85 | 86 | def __len__(self): 87 | return self.n 88 | 89 | def __getitem__(self, idx): 90 | 91 | # Important to open and close in getitem to enable num_workers>0 92 | 93 | if self.train: 94 | file_path=self.files_train[idx] 95 | data, target, time_all = sample_train( 96 | file_path, 97 | T=self.chunk_size, 98 | is_train_Enhanced=self.is_train_Enhanced, 99 | dt=self.dt 100 | ) 101 | 102 | data = chunk_evs_pol_dvs_gesture( 103 | data=data, 104 | dt=self.dt, 105 | T=self.chunk_size, 106 | size=self.size, 107 | ds=self.ds 108 | ) 109 | if self.is_spike: 110 | data = np.int64(data > 0) 111 | if self.interval_scaling: 112 | data = data / data.max() 113 | 114 | if self.transform is not None: 115 | data = self.transform(data) 116 | 117 | if self.target_transform is not None: 118 | target = self.target_transform(target) 119 | 120 | return data, target 121 | else: 122 | file_path = self.files_test[idx] 123 | data, target, time_all = sample_test( 124 | file_path, 125 | T=self.chunk_size, 126 | clip=self.clip, 127 | dt=self.dt 128 | ) 129 | 130 | data_temp = [] 131 | target_temp = [] 132 | for i in range(self.clip): 133 | 134 | if self.transform is not None: 135 | temp = chunk_evs_pol_dvs_gesture( 136 | data=data[i], 137 | dt=self.dt, 138 | T=self.chunk_size, 139 | size=self.size, 140 | ds=self.ds 141 | ) 142 | 143 | if self.is_spike: 144 | temp = np.int64(temp > 0) 145 | 146 | if self.interval_scaling: 147 | temp = temp / temp.max() 148 | 149 | data_temp.append(self.transform(temp)) 150 | 151 | if self.target_transform is not None: 152 | target_temp.append(self.target_transform(target)) 153 | 154 | data = torch.stack(data_temp) 155 | target = torch.stack(target_temp) 156 | 157 | return data, target 158 | 159 | 160 | def getDVSeventsDavis(file, numEvents=1e10, startTime=0): 161 | """ DESCRIPTION: This function reads a given aedat file and converts it into four lists indicating 162 | timestamps, x-coordinates, y-coordinates and polarities of the event stream. 163 | 164 | Args: 165 | file: the path of the file to be read, including extension (str). 166 | numEvents: the maximum number of events allowed to be read (int, default value=1e10). 
167 | startTime: the start event timestamp (in microseconds) where the conversion process begins (int, default value=0). 168 | 169 | Return: 170 | ts: list of timestamps in microseconds. 171 | x: list of x-coordinates in pixels. 172 | y: list of y-coordinates in pixels.` 173 | pol: list of polarities (0: on -> off, 1: off -> on). 174 | """ 175 | # print('\ngetDVSeventsDavis function called \n') 176 | sizeX = 128 177 | sizeY = 128 178 | x0 = 0 179 | y0 = 0 180 | x1 = sizeX 181 | y1 = sizeY 182 | 183 | # print('Reading in at most', str(numEvents)) 184 | 185 | triggerevent = int('400', 16) 186 | polmask = int('800', 16) 187 | xmask = int('003FF000', 16) 188 | ymask = int('7FC00000', 16) 189 | typemask = int('80000000', 16) 190 | typedvs = int('00', 16) 191 | xshift = 12 192 | yshift = 22 193 | polshift = 11 194 | x = [] 195 | y = [] 196 | ts = [] 197 | pol = [] 198 | numeventsread = 0 199 | 200 | length = 0 201 | aerdatafh = open(file, 'rb') 202 | k = 0 203 | p = 0 204 | statinfo = os.stat(file) 205 | if length == 0: 206 | length = statinfo.st_size 207 | # print("file size", length) 208 | 209 | lt = aerdatafh.readline() 210 | while lt and str(lt)[2] == "#": 211 | p += len(lt) 212 | k += 1 213 | lt = aerdatafh.readline() 214 | continue 215 | 216 | aerdatafh.seek(p) 217 | tmp = aerdatafh.read(8) 218 | p += 8 219 | while p < length: 220 | ad, tm = struct.unpack_from('>II', tmp) 221 | ad = abs(ad) 222 | if tm >= startTime: 223 | if (ad & typemask) == typedvs: 224 | xo = sizeX - 1 - float((ad & xmask) >> xshift) 225 | yo = float((ad & ymask) >> yshift) 226 | polo = 1 - float((ad & polmask) >> polshift) 227 | if xo >= x0 and xo < x1 and yo >= y0 and yo < y1: 228 | x.append(xo) 229 | y.append(yo) 230 | pol.append(polo) 231 | ts.append(tm) 232 | aerdatafh.seek(p) 233 | tmp = aerdatafh.read(8) 234 | p += 8 235 | numeventsread += 1 236 | 237 | # print('Total number of events read =', numeventsread) 238 | # print('Total number of DVS events returned =', len(ts)) 239 | 240 | return ts, x, y, pol 241 | 242 | 243 | def sample_train( 244 | hdf5_file, 245 | T=60, 246 | dt=1000, 247 | is_train_Enhanced=False 248 | ): 249 | label = mapping[hdf5_file.split('/')[-2]] 250 | data_dvs = np.array(getDVSeventsDavis(hdf5_file), dtype=np.int_) 251 | tbegin = data_dvs[0][0] 252 | time_all = data_dvs[0][-1] - data_dvs[0][0] 253 | if data_dvs[0][-1] - tbegin 0: 302 | overlap = int( 303 | np.floor((clip * T * dt - (end_time - start_time)) / clip)) 304 | for j in range(clip): 305 | start_point.append(j * (T * dt - overlap)) 306 | if start_point[-1] + T * dt > end_time: 307 | diff = start_point[-1] + T * dt - end_time 308 | start_point[-1] = start_point[-1] - diff 309 | else: 310 | overlap = int( 311 | np.floor(((end_time - start_time) - clip * T * dt) / clip)) 312 | for j in range(clip): 313 | start_point.append(j * (T * dt + overlap)) 314 | if start_point[-1] + T * dt > end_time: 315 | diff = start_point[-1] + T * dt - end_time 316 | start_point[-1] = start_point[-1] - diff 317 | 318 | temp = [] 319 | for start in start_point: 320 | idx_beg = find_first(tmad[:, 0], start) 321 | idx_end = find_first(tmad[:, 0][idx_beg:], start + T * dt) + idx_beg 322 | temp.append(tmad[idx_beg:idx_end][:, [0, 3, 1, 2]]) 323 | 324 | return temp, label, time_all 325 | 326 | 327 | def create_datasets( 328 | root=None, 329 | train=True, 330 | ds=4, 331 | dt=1000, 332 | transform=None, 333 | target_transform=None, 334 | n_events_attention=None, 335 | clip=10, 336 | is_train_Enhanced=False, 337 | is_spike = False, 338 | interval_scaling = False, 
339 | T=16, 340 | ): 341 | if isinstance(ds, int): 342 | ds = [ds, ds] 343 | 344 | size = [2, 346 // ds[0], 260 // ds[1]] 345 | 346 | if n_events_attention is None: 347 | def default_transform(): 348 | return Compose([ 349 | ToTensor(), 350 | # CenterCrop(128) 351 | ]) 352 | else: 353 | def default_transform(): 354 | return Compose([ 355 | ToTensor(), 356 | # CenterCrop(128) 357 | ]) 358 | 359 | if transform is None: 360 | transform = default_transform() 361 | 362 | if target_transform is None: 363 | target_transform = Compose([ 364 | Repeat(T), toOneHot(12) 365 | ]) 366 | 367 | dataset = DailyActionDataset( 368 | root, 369 | train=train, 370 | transform=transform, 371 | target_transform=target_transform, 372 | is_train_Enhanced=is_train_Enhanced, 373 | dt=dt, 374 | size=size, 375 | ds=ds, 376 | is_spike=is_spike, 377 | interval_scaling=interval_scaling, 378 | T=T, 379 | clip=clip, 380 | ) 381 | return dataset 382 | 383 | -------------------------------------------------------------------------------- /datasets/event_drop.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | # EventDrop augmentation by randomly dropping events 6 | def event_drop(events, resolution): 7 | raw_events = events 8 | option = np.random.randint( 9 | 0, 4 10 | ) # 0: identity, 1: drop_by_time, 2: drop_by_area, 3: random_drop 11 | if option == 0: # identity, do nothing 12 | return events 13 | elif option == 1: # drop_by_time 14 | T = np.random.randint(1, 10) / 10.0 # np.random.uniform(0.1, 0.9) 15 | events = drop_by_time(events, T=T) 16 | elif option == 2: # drop by area 17 | area_ratio = ( 18 | np.random.randint(1, 20) / 20.0 19 | ) # np.random.uniform(0.05, 0.1, 0.15, 0.2, 0.25) 20 | events = drop_by_area(events, resolution=resolution, area_ratio=area_ratio) 21 | elif option == 3: # random drop 22 | ratio = np.random.randint(1, 10) / 10.0 # np.random.uniform(0.1, 0.9) 23 | events = random_drop(events, ratio=ratio) 24 | 25 | if len(events) == 0: # avoid dropping all the events 26 | events = raw_events 27 | 28 | return events 29 | 30 | 31 | # drop the events within a given time interval 32 | # param: events contains x, y, t, p information 33 | def drop_by_time(events, T=0): 34 | # assert 0.1 <= T <= 0.5 35 | 36 | # time interval 37 | t_start = np.random.uniform(0, 1) 38 | if T == 0: # randomly choose a value between [0.1, 0.9] 39 | T = np.random.randint(1, 10) / 10.0 40 | t_end = t_start + T 41 | 42 | timestamps = events[:, 0] 43 | max_t = max(timestamps) 44 | idx = (timestamps < (max_t * t_start)) | ( 45 | timestamps > (max_t * t_end) 46 | ) # remaining events that are not within the given time interval 47 | 48 | return events[idx] 49 | 50 | 51 | # drop the events within a a fixed area constrained by X, Y 52 | # area ratio: the ratio between the area whose pixels are dropped and the overall pixel range 53 | def drop_by_area(events, resolution=(128, 128), area_ratio=0): 54 | # assert 0.1 <= area_ratio <= 0.3 55 | 56 | # get the area whose events are to be dropped 57 | x0 = np.random.uniform(resolution[0]) 58 | y0 = np.random.uniform(resolution[1]) 59 | 60 | if area_ratio == 0: 61 | area_ratio = np.random.randint(1, 6) / 20.0 62 | 63 | x_out = resolution[0] * area_ratio 64 | y_out = resolution[1] * area_ratio 65 | 66 | x0 = int(max(0, x0 - x_out / 2.0)) 67 | y0 = int(max(0, y0 - y_out / 2.0)) 68 | 69 | x1 = min(resolution[0], x0 + x_out) 70 | y1 = min(resolution[1], y0 + y_out) 71 | 72 | xy = (x0, x1, y0, y1) # rectangele to be 
dropped 73 | 74 | idx1 = (events[:, 2] < xy[0]) | (events[:, 2] > xy[1]) 75 | idx2 = (events[:, 3] < xy[2]) | (events[:, 3] > xy[3]) 76 | idx = idx1 & idx2 77 | 78 | return events[idx] 79 | 80 | 81 | # randomly drop a proportion of events 82 | def random_drop(events, ratio=0): 83 | # assert 0.1 <= ratio <= 0.5 84 | 85 | if ratio == 0: 86 | ratio = np.random.randint(1, 10) / 10.0 87 | 88 | N = events.shape[0] # number of total events 89 | num_drop = int(N * ratio) # number of events to be dropped 90 | idx = random.sample(list(np.arange(0, N)), N - num_drop) 91 | 92 | return events[idx] 93 | 94 | 95 | # randomly shift events 96 | def random_shift_events(events, max_shift=20, resolution=(128, 128), p=0.5): 97 | H, W = resolution 98 | if np.random.random() < p: 99 | x_shift, y_shift = np.random.randint(-max_shift, max_shift + 1, size=(2,)) 100 | events[:, 2] += x_shift 101 | events[:, 3] += y_shift 102 | valid_events = ( 103 | (events[:, 2] >= 0) 104 | & (events[:, 2] < W) 105 | & (events[:, 3] >= 0) 106 | & (events[:, 3] < H) 107 | ) 108 | events = events[valid_events] 109 | return events 110 | 111 | 112 | # randomly flip events 113 | def random_flip_events_along_x(events, resolution=(128, 128), p=0.5): 114 | H, W = resolution 115 | if np.random.random() < p: 116 | events[:, 2] = W - 1 - events[:, 2] 117 | return events 118 | -------------------------------------------------------------------------------- /datasets/events_timeslices.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | 4 | 5 | def expand_targets(targets, T=500, burnin=0): 6 | y = np.tile(targets.copy(), [T, 1, 1]) 7 | y[:burnin] = 0 8 | return y 9 | 10 | 11 | def one_hot(mbt, num_classes): 12 | out = np.zeros([mbt.shape[0], num_classes]) 13 | out[np.arange(mbt.shape[0], dtype="int"), mbt.astype("int")] = 1 14 | return out 15 | 16 | 17 | def find_first(a, tgt): 18 | return bisect.bisect_left(a, tgt) 19 | 20 | 21 | def cast_evs(evs): 22 | ts = (evs[:, 0] * 1e6).astype("uint64") 23 | ad = (evs[:, 1:]).astype("uint64") 24 | return ts, ad 25 | 26 | 27 | def get_subsampled_coordinates(evs, ds_h, ds_w): 28 | x_coords = evs[:, 1] // ds_w 29 | y_coords = evs[:, 2] // ds_h 30 | if x_coords.dtype != np.int: 31 | x_coords = x_coords.astype(int) 32 | if y_coords.dtype != np.int: 33 | y_coords = y_coords.astype(int) 34 | return x_coords, y_coords 35 | 36 | 37 | def get_binary_frame_np(arr, evs, ds_w=1, ds_h=1): 38 | x_coords, y_coords = get_subsampled_coordinates(evs, ds_h, ds_w) 39 | arr[x_coords, y_coords] = 2 * evs[:, 3] - 1 40 | 41 | 42 | def get_binary_frame(arr, evs, ds_w=1, ds_h=1): 43 | x_coords, y_coords = get_subsampled_coordinates(evs, ds_h, ds_w) 44 | arr[x_coords, y_coords] = 1 45 | 46 | 47 | def get_slice(times, addrs, start_time, end_time): 48 | try: 49 | idx_beg = find_first(times, start_time) 50 | idx_end = find_first(times[idx_beg:], end_time) + idx_beg 51 | return times[idx_beg:idx_end] - times[idx_beg], addrs[idx_beg:idx_end] 52 | except IndexError: 53 | raise IndexError("Empty batch found") 54 | 55 | 56 | def get_event_slice(times, addrs, start_time, T, size=[128, 128], ds=1, dt=1000): 57 | try: 58 | idx_beg = find_first(times, start_time) 59 | idx_end = find_first(times[idx_beg:], start_time + T * dt) + idx_beg 60 | return chunk_evs_pol_dvs( 61 | times[idx_beg:idx_end], 62 | addrs[idx_beg:idx_end], 63 | deltat=dt, 64 | chunk_size=T, 65 | size=size, 66 | ds_w=ds, 67 | ds_h=ds, 68 | ) 69 | except IndexError: 70 | raise IndexError("Empty batch 
found") 71 | 72 | 73 | def get_tmad_slice(times, addrs, start_time, T): 74 | try: 75 | idx_beg = find_first(times, start_time) 76 | idx_end = find_first(times[idx_beg:], start_time + T) + idx_beg 77 | return np.column_stack([times[idx_beg:idx_end], addrs[idx_beg:idx_end]]) 78 | except IndexError: 79 | raise IndexError("Empty batch found") 80 | 81 | 82 | def get_time_surface(evs, invtau=1e-6, size=(346, 260, 2)): 83 | tr = np.zeros(size, "int64") - np.inf 84 | 85 | for ev in evs: 86 | tr[ev[2], ev[1], ev[3]] = ev[0] 87 | 88 | a = np.exp(tr[:, :, 0] * invtau) - np.exp(tr[:, :, 1] * invtau) 89 | 90 | return a 91 | 92 | 93 | def chunk_evs_dvs(evs, deltat=1000, chunk_size=500, size=[304, 240], ds_w=1, ds_h=1): 94 | t_start = evs[0, 0] 95 | ts = range(t_start + chunk_size, t_start + chunk_size * deltat, deltat) 96 | chunks = np.zeros([len(ts)] + size, dtype="int8") 97 | idx_start = 0 98 | idx_end = 0 99 | for i, t in enumerate(ts): 100 | idx_end += find_first(evs[idx_end:, 0], t) 101 | if idx_end > idx_start: 102 | get_binary_frame_np( 103 | chunks[i, ...], evs[idx_start:idx_end], ds_h=ds_h, ds_w=ds_w 104 | ) 105 | idx_start = idx_end 106 | return chunks 107 | 108 | 109 | def frame_evs(times, addrs, deltat=1000, duration=500, size=[240], downsample=[1]): 110 | t_start = times[0] 111 | ts = range(t_start, t_start + duration * deltat, deltat) 112 | chunks = np.zeros([len(ts)] + size, dtype="int8") 113 | idx_start = 0 114 | idx_end = 0 115 | for i, t in enumerate(ts): 116 | idx_end += find_first(times[idx_end:], t) 117 | if idx_end > idx_start: 118 | ee = addrs[idx_start:idx_end] 119 | ev = [(ee[:, i] // d).astype(np.int) for i, d in enumerate(downsample)] 120 | np.add.at(chunks, tuple([i] + ev), 1) 121 | idx_start = idx_end 122 | return chunks 123 | 124 | 125 | def chunk_evs_pol_dvs( 126 | times, addrs, deltat=1000, chunk_size=500, size=[2, 304, 240], ds_w=1, ds_h=1 127 | ): 128 | t_start = times[0] 129 | ts = range(t_start, t_start + chunk_size * deltat, deltat) 130 | chunks = np.zeros([len(ts)] + size, dtype="int8") 131 | idx_start = 0 132 | idx_end = 0 133 | for i, t in enumerate(ts): 134 | idx_end += find_first(times[idx_end:], t) 135 | if idx_end > idx_start: 136 | ee = addrs[idx_start:idx_end] 137 | pol, x, y = ( 138 | ee[:, 2], 139 | (ee[:, 0] // ds_w).astype(np.int), 140 | (ee[:, 1] // ds_h).astype(np.int), 141 | ) 142 | np.add.at(chunks, (i, pol, x, y), 1) 143 | idx_start = idx_end 144 | return chunks 145 | 146 | 147 | def chunk_evs_pol_dvs_gesture(data, dt=1000, T=500, size=[2, 304, 240], ds=[4, 4]): 148 | try: 149 | t_start = data[0][0] 150 | except Exception as e: 151 | t_start = 0 152 | ts = range(t_start, t_start + T * dt, dt) 153 | chunks = np.zeros([len(ts)] + size, dtype="int64") 154 | idx_start = 0 155 | idx_end = 0 156 | for i, t in enumerate(ts): 157 | idx_end += find_first(data[idx_end:, 0], t + dt) 158 | if idx_end > idx_start: 159 | ee = data[idx_start:idx_end, 1:] 160 | pol, x, y = ( 161 | ee[:, 0], 162 | np.floor(ee[:, 1] / ds[0]).astype(np.int64), 163 | np.floor(ee[:, 2] / ds[1]).astype(np.int64), 164 | ) 165 | np.add.at(chunks, (i, pol, x, y), 1) 166 | idx_start = idx_end 167 | return chunks 168 | 169 | 170 | def chunk_evs_pol_dvs_gait(data, dt=1000, T=500, size=[2, 304, 240], ds=[4, 4]): 171 | if len(data) == 0: 172 | return np.zeros( 173 | [T] + [2] + [int(128 / ds[0])] + [int(128 / ds[1])], dtype="int64" 174 | ) 175 | t_start = data[0][0] 176 | ts = range(t_start, t_start + T * dt, dt) 177 | chunks = np.zeros( 178 | [len(ts)] + [2] + [int(128 / ds[0])] + [int(128 / 
ds[1])], dtype="int64" 179 | ) 180 | idx_start = 0 181 | idx_end = 0 182 | for i, t in enumerate(ts): 183 | idx_end += find_first(data[idx_end:, 0], t + dt) 184 | if idx_end > idx_start: 185 | ee = data[idx_start:idx_end, 1:] 186 | pol, x, y = ( 187 | ee[:, 2], 188 | np.floor(ee[:, 0] / ds[0]).astype(np.int64), 189 | np.floor(ee[:, 1] / ds[1]).astype(np.int64), 190 | ) 191 | np.add.at(chunks, (i, pol, x, y), 1) 192 | idx_start = idx_end 193 | return chunks 194 | -------------------------------------------------------------------------------- /datasets/gait/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from torchvision.transforms import Compose 7 | 8 | from datasets.transforms import find_first, Repeat, toOneHot, ToTensor 9 | from datasets.events_timeslices import chunk_evs_pol_dvs_gait, get_tmad_slice 10 | from datasets.event_drop import event_drop 11 | 12 | 13 | class DVS128GaitDataset(Dataset): 14 | def __init__( 15 | self, 16 | path, 17 | dt=1000, 18 | T=10, 19 | train=True, 20 | is_train_Enhanced=False, 21 | clips=1, 22 | is_spike=False, 23 | ds=None, 24 | ): 25 | super(DVS128GaitDataset, self).__init__() 26 | if ds is None: 27 | ds = [1, 1] 28 | self.train = train 29 | self.dt = dt 30 | self.T = T 31 | self.is_train_Enhanced = is_train_Enhanced 32 | self.clips = clips 33 | self.is_spike = is_spike 34 | self.ds = ds 35 | 36 | if self.train: 37 | train_npy_path = os.path.join(path, 'train') 38 | self.train_data = np.load(os.path.join(train_npy_path, 'train_data.npy'), allow_pickle=True) 39 | self.train_target = np.load(os.path.join(train_npy_path, 'train_target.npy'), allow_pickle=True) 40 | else: 41 | test_npy_path = os.path.join(path, 'test') 42 | self.test_data = np.load(os.path.join(test_npy_path, 'test_data.npy'), allow_pickle=True) 43 | self.test_target = np.load(os.path.join(test_npy_path, 'test_target.npy'), allow_pickle=True) 44 | 45 | def __len__(self): 46 | if self.train: 47 | return len(self.train_data) 48 | else: 49 | return len(self.test_data) 50 | 51 | def __getitem__(self, idx): 52 | if self.train: 53 | data = self.train_data[idx] 54 | time_all = data[-1][0] - data[0][0] 55 | 56 | data = sample_train( 57 | data=data, 58 | dt=self.dt, 59 | T=self.T, 60 | is_train_Enhanced=self.is_train_Enhanced, 61 | ) 62 | data = event_drop(data, resolution=[128 // self.ds[0], 128 // self.ds[1]]) 63 | 64 | data = chunk_evs_pol_dvs_gait( 65 | data=data, 66 | dt=self.dt, 67 | T=self.T, 68 | ds=self.ds 69 | ) 70 | 71 | if self.is_spike: 72 | data = np.int64(data > 0) 73 | 74 | target_idx = self.train_target[idx] 75 | label = np.zeros((20)) 76 | label[target_idx] = 1.0 77 | 78 | return data, label 79 | else: 80 | data = self.test_data[idx] 81 | time_all = data[-1][0] - data[0][0] 82 | 83 | data = sample_test( 84 | data=data, 85 | dt=self.dt, 86 | T=self.T, 87 | clips=self.clips, 88 | ) 89 | 90 | target_idx = self.test_target[idx] 91 | label = np.zeros((20)) 92 | label[target_idx] = 1.0 93 | 94 | data_temp = [] 95 | target_temp = [] 96 | for i in range(self.clips): 97 | temp = chunk_evs_pol_dvs_gait( 98 | data=data[i], 99 | dt=self.dt, 100 | T=self.T, 101 | ds=self.ds 102 | ) 103 | 104 | if self.is_spike: 105 | temp = np.int64(temp > 0) 106 | 107 | data_temp.append(temp) 108 | 109 | target_temp.append(label) 110 | 111 | data = np.array(data_temp) 112 | target = np.array(target_temp) 113 | 114 | return data, target 115 | 116 | 117 | def 
sample_train( 118 | data, 119 | T=60, 120 | dt=1000, 121 | is_train_Enhanced=False 122 | ): 123 | tbegin = data[:, 0][0] 124 | tend = np.maximum(0, data[:, 0][-1] - T * dt) 125 | 126 | start_time = random.randint(tbegin, tend) if is_train_Enhanced else tbegin 127 | 128 | tmad = get_tmad_slice( 129 | data[:, 0], 130 | data[:, 1:4], 131 | start_time, 132 | T * dt 133 | ) 134 | if len(tmad) == 0: 135 | return tmad 136 | tmad[:, 0] -= tmad[0, 0] 137 | return tmad 138 | 139 | 140 | def sample_test( 141 | data, 142 | T=60, 143 | clips=10, 144 | dt=1000 145 | ): 146 | tbegin = data[:, 0][0] 147 | tend = np.maximum(0, data[:, 0][-1]) 148 | 149 | tmad = get_tmad_slice( 150 | data[:, 0], 151 | data[:, 1:4], 152 | tbegin, 153 | tend - tbegin 154 | ) 155 | # shift timestamps so the slice starts from zero 156 | tmad[:, 0] -= tmad[0, 0] 157 | 158 | start_time = tmad[0, 0] 159 | end_time = tmad[-1, 0] 160 | 161 | start_point = [] 162 | if clips * T * dt - (end_time - start_time) > 0: 163 | overlap = int( 164 | np.floor((clips * T * dt - (end_time - start_time)) / clips)) 165 | for j in range(clips): 166 | start_point.append(j * (T * dt - overlap)) 167 | if start_point[-1] + T * dt > end_time: 168 | diff = start_point[-1] + T * dt - end_time 169 | start_point[-1] = start_point[-1] - diff 170 | else: 171 | overlap = int( 172 | np.floor(((end_time - start_time) - clips * T * dt) / clips)) 173 | for j in range(clips): 174 | start_point.append(j * (T * dt + overlap)) 175 | if start_point[-1] + T * dt > end_time: 176 | diff = start_point[-1] + T * dt - end_time 177 | start_point[-1] = start_point[-1] - diff 178 | 179 | temp = [] 180 | for start in start_point: 181 | idx_beg = find_first(tmad[:, 0], start) 182 | idx_end = find_first(tmad[:, 0][idx_beg:], start + T * dt) + idx_beg 183 | temp.append(tmad[idx_beg:idx_end]) 184 | 185 | return temp 186 | 187 | 188 | def create_datasets( 189 | root=None, 190 | train=True, 191 | chunk_size_train=60, 192 | chunk_size_test=60, 193 | ds=4, 194 | dt=1000, 195 | transform_train=None, 196 | transform_test=None, 197 | target_transform_train=None, 198 | target_transform_test=None, 199 | n_events_attention=None, 200 | clip=10, 201 | is_train_Enhanced=False, 202 | is_spike=False, 203 | interval_scaling=False, 204 | T=16, 205 | ): 206 | if isinstance(ds, int): 207 | ds = [ds, ds] 208 | 209 | if n_events_attention is None: 210 | def default_transform(): 211 | return Compose([ 212 | ToTensor() 213 | ]) 214 | else: 215 | def default_transform(): 216 | return Compose([ 217 | ToTensor() 218 | ]) 219 | 220 | if transform_train is None: 221 | transform_train = default_transform() 222 | if transform_test is None: 223 | transform_test = default_transform() 224 | 225 | if target_transform_train is None: 226 | target_transform_train = Compose( 227 | [Repeat(chunk_size_train), toOneHot(11)]) 228 | if target_transform_test is None: 229 | target_transform_test = Compose( 230 | [Repeat(chunk_size_test), toOneHot(11)]) 231 | 232 | dataset = DVS128GaitDataset( 233 | root, 234 | dt=dt, 235 | T=T, 236 | train=train, 237 | is_train_Enhanced=is_train_Enhanced, 238 | clips=clip, 239 | is_spike=is_spike, 240 | ds=ds, 241 | ) 242 | return dataset 243 | -------------------------------------------------------------------------------- /datasets/gesture/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision.transforms import Compose 8 | 9 | from datasets.transforms 
import find_first, Repeat, toOneHot, ToTensor 10 | from datasets.events_timeslices import chunk_evs_pol_dvs_gesture, get_tmad_slice 11 | 12 | 13 | class DVSGestureDataset(Dataset): 14 | def __init__( 15 | self, 16 | root, 17 | train=True, 18 | transform=None, 19 | target_transform=None, 20 | clip=10, 21 | is_train_Enhanced=False, 22 | dt=1000, 23 | size=[2, 32, 32], 24 | ds=4, 25 | is_spike=False, 26 | interval_scaling=False, 27 | T=16, 28 | ): 29 | super(DVSGestureDataset, self).__init__() 30 | 31 | # self.n = 0 32 | self.root = root 33 | self.train = train 34 | self.chunk_size = T 35 | self.clip = clip 36 | self.is_train_Enhanced = is_train_Enhanced 37 | self.dt = dt 38 | self.transform = transform 39 | self.target_transform = target_transform 40 | self.size = size 41 | self.ds = ds 42 | 43 | self.is_spike = is_spike 44 | self.interval_scaling = interval_scaling 45 | 46 | if train: 47 | root_train = os.path.join(self.root, 'train') 48 | for _, _, self.files_train in os.walk(root_train): 49 | pass 50 | self.n = len(self.files_train) 51 | else: 52 | root_test = os.path.join(self.root, 'test') 53 | for _, _, self.files_test in os.walk(root_test): 54 | pass 55 | self.n = len(self.files_test) 56 | 57 | def __len__(self): 58 | return self.n 59 | 60 | def __getitem__(self, idx): 61 | # Important to open and close in getitem to enable num_workers > 0 62 | if self.train: 63 | # assert idx < 1175 64 | root_test = os.path.join(self.root, 'train') 65 | 66 | with h5py.File(root_test + os.sep + self.files_train[idx], 'r', swmr=True, libver="latest") as f: 67 | data, target = sample_train( 68 | f, 69 | T=self.chunk_size, 70 | is_train_Enhanced=self.is_train_Enhanced, 71 | dt=self.dt 72 | ) 73 | 74 | # data = event_drop(data, resolution=[128 // self.ds[0], 128 // self.ds[1]]) 75 | data = chunk_evs_pol_dvs_gesture( 76 | data=data, 77 | dt=self.dt, 78 | T=self.chunk_size, 79 | size=self.size, 80 | ds=self.ds 81 | ) 82 | 83 | if self.is_spike: 84 | data = np.int64(data > 0) 85 | if self.interval_scaling: 86 | data = data / data.max() 87 | 88 | if self.transform is not None: 89 | data = self.transform(data) 90 | 91 | if self.target_transform is not None: 92 | target = self.target_transform(target) 93 | 94 | return data, target 95 | else: 96 | # assert idx < 288 97 | root_test = os.path.join(self.root, 'test') 98 | 99 | with h5py.File(root_test + os.sep + self.files_test[idx], 'r', swmr=True, libver="latest") as f: 100 | data, target = sample_test( 101 | f, 102 | T=self.chunk_size, 103 | clip=self.clip, 104 | dt=self.dt 105 | ) 106 | 107 | data_temp = [] 108 | target_temp = [] 109 | for i in range(self.clip): 110 | 111 | if self.transform is not None: 112 | temp = chunk_evs_pol_dvs_gesture( 113 | data=data[i], 114 | dt=self.dt, 115 | T=self.chunk_size, 116 | size=self.size, 117 | ds=self.ds 118 | ) 119 | 120 | if self.is_spike: 121 | temp = np.int64(temp > 0) 122 | if self.interval_scaling: 123 | temp = temp / temp.max() 124 | 125 | data_temp.append(self.transform(temp)) 126 | 127 | if self.target_transform is not None: 128 | target_temp.append(self.target_transform(target)) 129 | 130 | data = torch.stack(data_temp) 131 | target = torch.stack(target_temp) 132 | 133 | return data, target 134 | 135 | 136 | def sample_train( 137 | hdf5_file, 138 | T=60, 139 | dt=1000, 140 | is_train_Enhanced=False 141 | ): 142 | label = hdf5_file['labels'][()] 143 | 144 | tbegin = hdf5_file['times'][0] 145 | tend = np.maximum(0, hdf5_file['times'][-1] - T * dt) 146 | 147 | start_time = np.random.randint(tbegin, tend) if 
is_train_Enhanced else 0 148 | 149 | tmad = get_tmad_slice( 150 | hdf5_file['times'][()], 151 | hdf5_file['addrs'][()], 152 | start_time, 153 | T * dt 154 | ) 155 | tmad[:, 0] -= tmad[0, 0] 156 | return tmad[:, [0, 3, 1, 2]], label 157 | 158 | 159 | def sample_test( 160 | hdf5_file, 161 | T=60, 162 | clip=10, 163 | dt=1000 164 | ): 165 | label = hdf5_file['labels'][()] 166 | 167 | tbegin = hdf5_file['times'][0] 168 | tend = np.maximum(0, hdf5_file['times'][-1]) 169 | 170 | tmad = get_tmad_slice( 171 | hdf5_file['times'][()], 172 | hdf5_file['addrs'][()], 173 | tbegin, 174 | tend - tbegin 175 | ) 176 | tmad[:, 0] -= tmad[0, 0] 177 | 178 | start_time = tmad[0, 0] 179 | end_time = tmad[-1, 0] 180 | 181 | start_point = [] 182 | if clip * T * dt - (end_time - start_time) > 0: 183 | overlap = int( 184 | np.floor((clip * T * dt - (end_time - start_time)) / clip)) 185 | for j in range(clip): 186 | start_point.append(j * (T * dt - overlap)) 187 | if start_point[-1] + T * dt > end_time: 188 | diff = start_point[-1] + T * dt - end_time 189 | start_point[-1] = start_point[-1] - diff 190 | else: 191 | overlap = int( 192 | np.floor(((end_time - start_time) - clip * T * dt) / clip)) 193 | for j in range(clip): 194 | start_point.append(j * (T * dt + overlap)) 195 | if start_point[-1] + T * dt > end_time: 196 | diff = start_point[-1] + T * dt - end_time 197 | start_point[-1] = start_point[-1] - diff 198 | 199 | temp = [] 200 | for start in start_point: 201 | idx_beg = find_first(tmad[:, 0], start) 202 | idx_end = find_first(tmad[:, 0][idx_beg:], start + T * dt) + idx_beg 203 | temp.append(tmad[idx_beg:idx_end][:, [0, 3, 1, 2]]) 204 | 205 | return temp, label 206 | 207 | 208 | def create_datasets( 209 | root=None, 210 | train=True, 211 | ds=4, 212 | dt=1000, 213 | transform=None, 214 | target_transform=None, 215 | n_events_attention=None, 216 | clip=10, 217 | is_train_Enhanced=False, 218 | is_spike = False, 219 | interval_scaling = False, 220 | T=16, 221 | ): 222 | if isinstance(ds, int): 223 | ds = [ds, ds] 224 | 225 | size = [2, 128 // ds[0], 128 // ds[1]] 226 | 227 | if n_events_attention is None: 228 | def default_transform(): 229 | return Compose([ 230 | ToTensor() 231 | ]) 232 | else: 233 | def default_transform(): 234 | return Compose([ 235 | ToTensor() 236 | ]) 237 | 238 | if transform is None: 239 | transform = default_transform() 240 | 241 | if target_transform is None: 242 | target_transform = Compose([ 243 | Repeat(T), toOneHot(11) 244 | ]) 245 | 246 | dataset = DVSGestureDataset( 247 | root, 248 | train=train, 249 | transform=transform, 250 | target_transform=target_transform, 251 | is_train_Enhanced=is_train_Enhanced, 252 | dt=dt, 253 | size=size, 254 | ds=ds, 255 | is_spike=is_spike, 256 | interval_scaling=interval_scaling, 257 | T=T, 258 | clip=clip, 259 | ) 260 | return dataset 261 | -------------------------------------------------------------------------------- /datasets/recogition/dataloaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from datasets.transforms import find_first, Repeat, toOneHot, ToTensor 8 | from datasets.events_timeslices import chunk_evs_pol_dvs_gesture, get_tmad_slice 9 | from torchvision.transforms import Compose, RandomCrop, CenterCrop 10 | 11 | 12 | mapping = { 13 | 'arm crossing': 0, 14 | 'turning around': 1, 15 | 'walking': 2, 16 | 'picking up': 3, 17 | 'throwing': 4, 18 | 'get-up': 5, 19 | 'jumping': 6, 
20 | 'kicking': 7, 21 | 'sit-down': 8, 22 | 'waving': 9, 23 | } 24 | 25 | 26 | class DVSRecognitionDataset_Recogition(Dataset): 27 | def __init__( 28 | self, 29 | root, 30 | train=True, 31 | transform=None, 32 | target_transform=None, 33 | clip=10, 34 | is_train_Enhanced=False, 35 | dt=1000, 36 | size=[2, 32, 32], 37 | ds=4, 38 | is_spike=False, 39 | interval_scaling=False, 40 | T=16, 41 | ): 42 | super(DVSRecognitionDataset_Recogition, self).__init__() 43 | 44 | self.n = 0 45 | self.root = root 46 | self.train = train 47 | self.chunk_size = T 48 | self.clip = clip 49 | self.is_train_Enhanced = is_train_Enhanced 50 | self.dt = dt 51 | self.transform = transform 52 | self.target_transform = target_transform 53 | self.size = size 54 | self.ds = ds 55 | 56 | self.is_spike = is_spike 57 | self.interval_scaling = interval_scaling 58 | 59 | if train: 60 | root_train = os.path.join(self.root, 'train') 61 | self.files_train=[] 62 | for self.now_file, self.label_train_one, self.files_train_one in os.walk(root_train): 63 | for now_data_file in self.files_train_one: 64 | if now_data_file[-3:]=='npy': 65 | self.files_train.append(os.path.join(self.now_file,now_data_file)) 66 | self.n = len(self.files_train) 67 | else: 68 | root_test = os.path.join(self.root, 'test') 69 | self.files_test = [] 70 | for self.now_file, self.label_train_one, self.files_train_one in os.walk(root_test): 71 | for now_data_file in self.files_train_one: 72 | if now_data_file[-3:] == 'npy': 73 | self.files_test.append(os.path.join(self.now_file,now_data_file)) 74 | self.n = len(self.files_test) 75 | 76 | def __len__(self): 77 | return self.n 78 | 79 | def __getitem__(self, idx): 80 | 81 | # Important to open and close in getitem to enable num_workers>0 82 | 83 | if self.train: 84 | root_test = os.path.join(self.root, 'train') 85 | 86 | file_path = self.files_train[idx] 87 | data, target, time_all = sample_train( 88 | file_path, 89 | T=self.chunk_size, 90 | is_train_Enhanced=self.is_train_Enhanced, 91 | dt=self.dt, 92 | ) 93 | 94 | data = chunk_evs_pol_dvs_gesture( 95 | data=data, 96 | dt=self.dt, 97 | T=self.chunk_size, 98 | size=self.size, 99 | ds=self.ds 100 | ) 101 | if self.is_spike: 102 | data = np.int64(data>0) 103 | if self.interval_scaling: 104 | data = data / data.max() 105 | 106 | if self.transform is not None: 107 | data = self.transform(data) 108 | 109 | if self.target_transform is not None: 110 | target = self.target_transform(target) 111 | 112 | return data, target 113 | else: 114 | file_path = self.files_test[idx] 115 | data, target, time_all = sample_test( 116 | file_path, 117 | T=self.chunk_size, 118 | clip=self.clip, 119 | dt=self.dt 120 | ) 121 | 122 | data_temp = [] 123 | target_temp = [] 124 | for i in range(self.clip): 125 | 126 | if self.transform is not None: 127 | temp = chunk_evs_pol_dvs_gesture( 128 | data=data[i], 129 | dt=self.dt, 130 | T=self.chunk_size, 131 | size=self.size, 132 | ds=self.ds 133 | ) 134 | 135 | if self.is_spike: 136 | temp = np.int64(temp>0) 137 | 138 | if self.interval_scaling: 139 | temp = temp / temp.max() 140 | 141 | data_temp.append(self.transform(temp)) 142 | 143 | if self.target_transform is not None: 144 | target_temp.append(self.target_transform(target)) 145 | 146 | data = torch.stack(data_temp) 147 | target = torch.stack(target_temp) 148 | 149 | return data, target 150 | 151 | def sample_train( 152 | hdf5_file, 153 | T=60, 154 | dt=1000, 155 | is_train_Enhanced=False, 156 | ): 157 | label = mapping[hdf5_file.split('/')[-2]] 158 | data_dvs = np.load(hdf5_file) 159 | 
time_all = data_dvs[0][-1] - data_dvs[0][0] 160 | 161 | tbegin = data_dvs[0][0] 162 | if data_dvs[0][-1] - tbegin < T*dt: 163 | start_time = tbegin 164 | else: 165 | tend = np.maximum(0, data_dvs[0][-1] - T * dt) 166 | # print('tend:',tend,'tbegin:',tbegin,'hdf5_file:',hdf5_file) 167 | try: 168 | start_time = np.random.randint(tbegin, tend) if is_train_Enhanced else 0 169 | except ValueError: 170 | start_time = tbegin 171 | 172 | tmad = get_tmad_slice( 173 | data_dvs[0], 174 | data_dvs[1:,:].T, 175 | start_time, 176 | T * dt 177 | ) 178 | try: 179 | tmad[:, 0] -= tmad[0, 0] 180 | except IndexError: 181 | pass # empty slice, leave timestamps untouched 182 | return tmad[:, [0, 3, 1, 2]], label, time_all 183 | 184 | 185 | def sample_test( 186 | hdf5_file, 187 | T=60, 188 | clip=10, 189 | dt=1000 190 | ): 191 | label = mapping[hdf5_file.split('/')[-2]] 192 | data_dvs = np.load(hdf5_file) 193 | tbegin = data_dvs[0][0] 194 | time_all = data_dvs[0][-1] - data_dvs[0][0] 195 | tend = np.maximum(0, data_dvs[0][-1]) 196 | 197 | tmad = get_tmad_slice( 198 | data_dvs[0], 199 | data_dvs[1:, :].T, 200 | tbegin, 201 | tend-tbegin 202 | ) 203 | tmad[:, 0] -= tmad[0, 0] 204 | 205 | start_time = tmad[0, 0] 206 | end_time = tmad[-1, 0] 207 | 208 | start_point = [] 209 | if clip * T * dt - (end_time - start_time) > 0: 210 | overlap = int( 211 | np.floor((clip * T * dt - (end_time - start_time)) / clip)) 212 | for j in range(clip): 213 | start_point.append(j * (T * dt - overlap)) 214 | if start_point[-1] + T * dt > end_time: 215 | diff = start_point[-1] + T * dt - end_time 216 | start_point[-1] = start_point[-1] - diff 217 | else: 218 | overlap = int( 219 | np.floor(((end_time - start_time) - clip * T * dt) / clip)) 220 | for j in range(clip): 221 | start_point.append(j * (T * dt + overlap)) 222 | if start_point[-1] + T * dt > end_time: 223 | diff = start_point[-1] + T * dt - end_time 224 | start_point[-1] = start_point[-1] - diff 225 | 226 | temp = [] 227 | for start in start_point: 228 | idx_beg = find_first(tmad[:, 0], start) 229 | idx_end = find_first(tmad[:, 0][idx_beg:], start + T * dt) + idx_beg 230 | temp.append(tmad[idx_beg:idx_end][:, [0, 3, 1, 2]]) 231 | 232 | return temp, label, time_all 233 | 234 | 235 | def create_datasets( 236 | root=None, 237 | train=True, 238 | ds=4, 239 | dt=1000, 240 | transform=None, 241 | target_transform=None, 242 | n_events_attention=None, 243 | clip=10, 244 | is_train_Enhanced=False, 245 | is_spike = False, 246 | interval_scaling = False, 247 | T=16, 248 | ): 249 | if isinstance(ds, int): 250 | ds = [ds, ds] 251 | 252 | size = [2, math.ceil(346 / ds[0]), 260 // ds[1]] 253 | 254 | if n_events_attention is None: 255 | def default_transform(): 256 | return Compose([ 257 | ToTensor(), 258 | # Resize( 259 | # size=(128, 128), # (height, width) 260 | # interpolation=2 # interpolation method; the default is usually fine 261 | # ) 262 | # CenterCrop(128) 263 | ]) 264 | else: 265 | def default_transform(): 266 | return Compose([ 267 | ToTensor(), 268 | # Resize( 269 | # size=(128, 128), # (height, width) 270 | # interpolation=2 # interpolation method; the default is usually fine 271 | # ) 272 | # CenterCrop(128) 273 | ]) 274 | 275 | if transform is None: 276 | transform = default_transform() 277 | 278 | if target_transform is None: 279 | target_transform = Compose([ 280 | Repeat(T), toOneHot(10) 281 | ]) 282 | 283 | dataset = DVSRecognitionDataset_Recogition( 284 | root, 285 | train=train, 286 | transform=transform, 287 | target_transform=target_transform, 288 | is_train_Enhanced=is_train_Enhanced, 289 | dt=dt, 290 | size=size, 291 | ds=ds, 292 | is_spike=is_spike, 293 | interval_scaling=interval_scaling, 294 
| T=T, 295 | clip=clip, 296 | ) 297 | return dataset 298 | 299 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch, bisect 4 | 5 | 6 | def find_first(a, tgt): 7 | return bisect.bisect_left(a, tgt) 8 | 9 | 10 | class toOneHot(object): 11 | def __init__(self, num_classes): 12 | self.num_classes = num_classes 13 | 14 | def __call__(self, integers): 15 | y_onehot = torch.FloatTensor(integers.shape[0], self.num_classes) 16 | y_onehot.zero_() 17 | return y_onehot.scatter_(1, torch.LongTensor(integers), 1) 18 | 19 | 20 | class Crop(object): 21 | def __init__(self, low_crop, high_crop): 22 | """ 23 | Crop all dimensions 24 | """ 25 | self.low = low_crop 26 | self.high = high_crop 27 | 28 | def __call__(self, tmad): 29 | idx = np.where(np.any(tmad > self.high, axis=1)) 30 | tmad = np.delete(tmad, idx, 0) 31 | idx = np.where(np.any(tmad < self.low, axis=1)) 32 | tmad = np.delete(tmad, idx, 0) 33 | return tmad 34 | 35 | def __repr__(self): 36 | return self.__class__.__name__ + "()" 37 | 38 | 39 | class CropDims(object): 40 | def __init__(self, low_crop, high_crop, dims): 41 | self.low_crop = low_crop 42 | self.high_crop = high_crop 43 | self.dims = dims 44 | 45 | def __call__(self, tmad): 46 | for i, d in enumerate(self.dims): 47 | idx = np.where(tmad[:, d] >= self.high_crop[i]) 48 | tmad = np.delete(tmad, idx, 0) 49 | idx = np.where(tmad[:, d] < self.low_crop[i]) 50 | tmad = np.delete(tmad, idx, 0) 51 | # Normalize 52 | tmad[:, d] = tmad[:, d] - self.low_crop[i] 53 | return tmad 54 | 55 | def __repr__(self): 56 | return self.__class__.__name__ + "()" 57 | 58 | 59 | class Attention(object): 60 | def __init__(self, n_attention_events, size): 61 | """ 62 | Crop around the median event in the last n_events. 63 | """ 64 | self.att_shape = np.array(size[1:], dtype=np.int64) 65 | self.n_att_events = n_attention_events 66 | 67 | def __call__(self, tmad): 68 | df = pd.DataFrame(tmad, columns=["t", "p", "x", "y"]) 69 | # compute centroid in x and y 70 | centroids = ( 71 | df.loc[:, ["x", "y"]] 72 | .rolling(window=self.n_att_events, min_periods=1) 73 | .median() 74 | .astype(int) 75 | ) 76 | # re-address (translate) events with respect to centroid corner 77 | df.loc[:, ["x", "y"]] -= centroids - self.att_shape // 2 78 | # remove out of range events 79 | df = df.loc[ 80 | (df.x >= 0) 81 | & (df.x < self.att_shape[1]) 82 | & (df.y >= 0) 83 | & (df.y < self.att_shape[0]) 84 | ] 85 | return df.to_numpy() 86 | 87 | def __repr__(self): 88 | return self.__class__.__name__ + "()" 89 | 90 | 91 | class ToChannelHeightWidth(object): 92 | def __call__(self, tmad): 93 | n = tmad.shape[1] 94 | if n == 2: 95 | o = np.zeros(tmad.shape[0], dtype=tmad.dtype) 96 | return np.column_stack([tmad, o, o]) 97 | 98 | elif n == 4: 99 | return tmad 100 | 101 | else: 102 | raise TypeError( 103 | "Wrong number of dimensions. Found {0}, expected 1 or 3".format(n - 1) 104 | ) 105 | 106 | def __repr__(self): 107 | return self.__class__.__name__ + "()" 108 | 109 | 110 | class ToCountFrame(object): 111 | """Convert address events to a count-frame tensor. 112 | 113 | Converts a tmad event array (t, p, x, y) to a numpy.ndarray of shape (T x C x H x W) holding per-bin event counts 
114 | """ 115 | 116 | def __init__(self, T=500, size=[2, 32, 32]): 117 | self.T = T 118 | self.size = size 119 | 120 | def __call__(self, tmad): 121 | times = tmad[:, 0] 122 | addrs = tmad[:, 1:] 123 | 124 | ts = range(0, self.T) 125 | chunks = np.zeros([len(ts)] + self.size, dtype="int8") 126 | idx_start = 0 127 | idx_end = 0 128 | for i, t in enumerate(ts): 129 | idx_end += find_first(times[idx_end:], t) 130 | if idx_end > idx_start: 131 | ee = addrs[idx_start:idx_end] 132 | i_pol_x_y = (i, ee[:, 0], ee[:, 1], ee[:, 2]) 133 | np.add.at(chunks, i_pol_x_y, 1) 134 | idx_start = idx_end 135 | return chunks 136 | 137 | def __repr__(self): 138 | return self.__class__.__name__ + "()" 139 | 140 | 141 | class Repeat(object): 142 | """ 143 | Replicate np.array (C) as (n_repeat X C). This is useful to transform sample labels into sequences 144 | """ 145 | 146 | def __init__(self, n_repeat): 147 | self.n_repeat = n_repeat 148 | 149 | def __call__(self, target): 150 | return np.tile(np.expand_dims(target, 0), [self.n_repeat, 1]) 151 | 152 | def __repr__(self): 153 | return self.__class__.__name__ + "()" 154 | 155 | 156 | class ToTensor(object): 157 | """Convert a ``numpy.ndarray`` to tensor. 158 | 159 | Converts a numpy.ndarray (T x H x W x C) to a torch.FloatTensor of shape (T X H x W x C) 160 | """ 161 | 162 | def __call__(self, frame): 163 | """ 164 | Args: 165 | frame (numpy.ndarray): numpy array of frames 166 | 167 | Returns: 168 | Tensor: Converted data. 169 | """ 170 | return torch.FloatTensor(frame) 171 | 172 | def __repr__(self): 173 | return self.__class__.__name__ + "()" 174 | -------------------------------------------------------------------------------- /final.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 train_test.py -arch A -T 60 -attention ASA -dataset gesture -data_path /DVSGesture_data -batch_size_train 32 -batch_size_test 4 -clip 10 -num_epochs 200 -seed 0 -ps 00111 -lam 0.5 -reduction 4 2 | CUDA_VISIBLE_DEVICES=0 python3 train_test.py -arch A -T 60 -attention ASA -dataset gait_day -num_classes 20 -data_path /dvs-gait/npy -batch_size_train 32 -batch_size_test 4 -clip 10 -num_epochs 200 -seed 0 -ps 00111 -lam 0.5 -weight_decay 1e-4 -lr 1e-4 -reduction 4 3 | CUDA_VISIBLE_DEVICES=0 python3 train_test.py -arch D -T 60 -attention ASA -dataset gait_night -num_classes 20 -data_path /dvs-gait/dvs-gait-night/DVS128_Gait_Night_npy -batch_size_train 32 -batch_size_test 8 -clip 10 -num_epochs 100 -seed 0 -ps 00111 -lam 0.5 -reduction 8 4 | CUDA_VISIBLE_DEVICES=0 python3 train_test.py -arch B -T 20 -attention ASA -dataset gesture_sj -data_path /DVS_Gesture_dataset -batch_size_train 8 -batch_size_test 4 -clip 1 -num_epochs 200 -seed 0 -lam 0.5 -lr 1e-3 -ds 1 -weight_decay 0 -dt 1 -reduction 4 5 | CUDA_VISIBLE_DEVICES=0 python3 train_test.py -arch B -T 36 -attention ASA -dataset recogition -data_path /Recognition -batch_size_train 4 -batch_size_test 1 -clip 10 -num_epochs 200 -seed 0 -lam 0.5 -lr 1e-4 -ds 2 -weight_decay 0 -dt 120 -num_classes 10 -reduction 4 6 | CUDA_VISIBLE_DEVICES=0 python3 train_test.py -arch B -T 36 -attention ASA -dataset action -data_path /DaliyAction/ -batch_size_train 4 -batch_size_test 1 -clip 10 -num_epochs 100 -seed 0 -lam 0.5 -lr 1e-3 -ds 2 -weight_decay 0 -dt 120 -num_classes 10 -reduction 4 -------------------------------------------------------------------------------- /firing_num.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
torch 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from spikingjelly.clock_driven import functional 8 | 9 | from model import create_net 10 | from dataset import create_dataset 11 | from config import parser 12 | from utils import set_seed 13 | 14 | 15 | @torch.no_grad() 16 | def firing_num(args): 17 | args.model.load_state_dict(torch.load(args.modelPath)) 18 | args.model.eval() 19 | 20 | all_idx = 0 21 | flag = True 22 | 23 | for n_iter, (input, labels) in enumerate(args.test_loader): 24 | functional.reset_net(args.model) 25 | 26 | b = input.size()[0] 27 | 28 | if "sj" not in args.dataset: 29 | input = input.reshape( 30 | b * args.clip, 31 | input.size()[2], 32 | input.size()[3], 33 | input.size()[4], 34 | input.size()[5], 35 | ) 36 | input = input.float().to(args.device) 37 | if len(labels.shape) == 3: 38 | labels = labels.reshape(b * args.clip, labels.size()[2]) 39 | labels = labels.float().to(args.device) 40 | else: 41 | labels = labels.reshape( 42 | b * args.clip, labels.size()[2], labels.size()[3] 43 | ) 44 | labels = labels[:, 1, :].float().to(args.device) 45 | else: 46 | input = input.float().to(args.device) 47 | if len(labels.shape) == 2: 48 | labels = labels.float().to(args.device) 49 | elif len(labels.shape) == 1: 50 | labels = F.one_hot(labels, args.num_classes).float().to(args.device) 51 | 52 | output, firing_num = args.model(input) 53 | 54 | if flag: 55 | neural_list = [] 56 | for f in firing_num: 57 | neural_list.append(f.numel() / f.shape[0] / f.shape[1]) 58 | flag = False 59 | 60 | _, predicted = torch.max(output.data, 1) 61 | _, labelTest = torch.max(labels.data, 1) 62 | for i in range(b): 63 | predicted_clips = predicted[i * args.clip : (i + 1) * args.clip] 64 | labelTest_clips = labelTest[i * args.clip : (i + 1) * args.clip] 65 | test_clip_correct = (predicted_clips == labelTest_clips).sum().item() 66 | if test_clip_correct / args.clip > 0.5: 67 | args.test_correct += 1 68 | args.test_acc = 100.0 * float(args.test_correct) / (float(n_iter + 1) * b) 69 | print( 70 | "iteration: {}/{}\tacc: {}".format( 71 | n_iter + 1, len(args.test_loader), args.test_acc 72 | ) 73 | ) 74 | 75 | list_ = [] 76 | for firing_single in firing_num: 77 | sub_list = [] 78 | firing_single = firing_single.cpu().detach().numpy() 79 | for T_ in range(args.T): 80 | sub_list.append(np.sum(firing_single[T_, ...])) 81 | sub_list.append(torch.from_numpy(firing_single[0, ...]).numel()) 82 | list_.append(sub_list) 83 | 84 | csv = pd.DataFrame(data=list_) 85 | if not os.path.exists(args.name): 86 | os.makedirs(args.name) 87 | csv.to_csv(args.name + os.sep + str(all_idx) + ".csv") 88 | all_idx += 1 89 | 90 | flag = True 91 | spiking_all = [] 92 | for idx in range(len(args.test_loader)): 93 | name = str(idx) + ".csv" 94 | df = pd.read_csv(os.path.join(args.name, name), header=None).values 95 | 96 | if flag: 97 | for layer in range(len(df) - 1): 98 | spiking_all.append(df[layer + 1, 1:]) 99 | flag = False 100 | else: 101 | for layer in range(len(df) - 1): 102 | spiking_all[layer] = spiking_all[layer] + df[layer + 1, 1:] 103 | 104 | firing_nums = [] 105 | for nums in spiking_all: 106 | sub_list = [] 107 | num = 0 108 | total = 0 109 | for idx in range(len(nums) - 1): 110 | num += nums[idx] 111 | total += nums[len(nums) - 1] 112 | sub_list.append(nums[idx] / nums[len(nums) - 1]) 113 | sub_list.append(num / total) 114 | firing_nums.append(sub_list) 115 | 116 | csv = pd.DataFrame(data=firing_nums) 117 | csv.to_csv(args.name + ".csv") 118 | 119 | total_spike = 0.0 120 | 
for neural, firing_num in zip(neural_list, firing_nums): 121 | total_spike += firing_num[-1] * neural 122 | print(total_spike / sum(neural_list)) 123 | 124 | 125 | def main(i): 126 | args = parser.parse_args() 127 | 128 | set_seed(args.seed) 129 | 130 | args.test_correct = 0.0 131 | args.recordPath = args.modelPath 132 | args.im_width, args.im_height = (128 // args.ds, 128 // args.ds) 133 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 134 | args.device_ids = range(torch.cuda.device_count()) 135 | 136 | args.name = ( 137 | args.dataset 138 | + "_dt=" 139 | + str(args.dt) 140 | + "ms" 141 | + "_T=" 142 | + str(args.T) 143 | + "_attn=" 144 | + args.attention 145 | + "_lam=" 146 | + str(args.lam) 147 | + "_seed=" 148 | + str(args.seed) 149 | + "_arch=" 150 | + str(args.arch) 151 | + "_" 152 | + str(i) 153 | ) 154 | args.modelNames = args.name + ".pth" 155 | args.modelPath = os.path.join( 156 | os.path.join(args.modelPath, args.attention), args.modelNames 157 | ) 158 | args.recordNames = args.name + ".csv" 159 | 160 | print(args) 161 | 162 | create_dataset(args=args) 163 | 164 | create_net(args=args) 165 | 166 | firing_num(args=args) 167 | 168 | 169 | if __name__ == "__main__": 170 | main(0) 171 | main(1) 172 | -------------------------------------------------------------------------------- /img/abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BICLab/ASA-SNN/c68b2d0211eb481c47af9c43fc2107f057a6066d/img/abstract.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_net(args): 5 | if args.arch == "A": 6 | from models.Net import DVSGestureNet 7 | elif args.arch == "B": 8 | from models.Net2 import DVSGestureNet 9 | elif args.arch == "C": 10 | from models.Net3 import DVSGestureNet 11 | elif args.arch == "D": 12 | from models.Net4 import DVSGestureNet 13 | 14 | model = DVSGestureNet(args).to(args.device) 15 | 16 | args.model = model 17 | # args.model = nn.DataParallel( 18 | # args.model, 19 | # device_ids=args.device_ids 20 | # ) 21 | 22 | args.optimizer = torch.optim.Adam( 23 | args.model.parameters(), 24 | lr=args.lr, 25 | betas=args.betas, 26 | eps=args.eps, 27 | weight_decay=args.weight_decay, 28 | ) 29 | 30 | # args.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 31 | # optimizer=args.optimizer, 32 | # mode='min', 33 | # factor=0.2, 34 | # patience=5, 35 | # verbose=True 36 | # ) 37 | args.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 38 | optimizer=args.optimizer, 39 | T_0=args.num_epochs, 40 | T_mult=1, 41 | eta_min=1e-6, 42 | verbose=True, 43 | ) 44 | -------------------------------------------------------------------------------- /models/Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.module.LIF import ConvAttLIF, AttLIF 5 | 6 | 7 | class DVSGestureNet(nn.Module): 8 | def __init__( 9 | self, 10 | args, 11 | conv_configs=[ 12 | ( 13 | 2, 14 | 64, 15 | 3, 16 | ), 17 | ( 18 | 64, 19 | 128, 20 | 3, 21 | ), 22 | ( 23 | 128, 24 | 128, 25 | 3, 26 | ), 27 | ], 28 | pool_kernels_size=[1, 2, 2], 29 | attn_flags=[1, 1, 1], 30 | ): 31 | super().__init__() 32 | 33 | dropout_ps = [] 34 | for i in range(len(args.ps)): 35 | if args.ps[i] == "1": 36 | dropout_ps.append(0.1) 37 | elif args.ps[i] == "0": 38 | 
dropout_ps.append(0.0) 39 | 40 | class ActFun(torch.autograd.Function): 41 | def __init__(self): 42 | super(ActFun, self).__init__() 43 | 44 | @staticmethod 45 | def forward(ctx, input): 46 | ctx.save_for_backward(input) 47 | return input.ge(0.0).float() 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | (input,) = ctx.saved_tensors 52 | temp = abs(input) < args.lens 53 | return grad_output * temp.float() / (2 * args.lens) 54 | 55 | ConvAttLIFs = [] 56 | for i, (conv_config, pooling_kernel_size, attn_flag) in enumerate( 57 | zip( 58 | conv_configs, 59 | pool_kernels_size, 60 | attn_flags, 61 | ) 62 | ): 63 | if attn_flag == 0: 64 | attn = "no" 65 | elif attn_flag == 1: 66 | attn = args.attention 67 | ConvAttLIFs.append( 68 | ConvAttLIF( 69 | p=dropout_ps[i], 70 | attention=attn, 71 | inputSize=conv_config[0], 72 | hiddenSize=conv_config[1], 73 | kernel_size=conv_config[2], 74 | spikeActFun=ActFun.apply, 75 | init_method=args.init_method, 76 | useBatchNorm=True, 77 | pooling_kernel_size=pooling_kernel_size, 78 | T=args.T, 79 | pa_dict={ 80 | "alpha": args.alpha, 81 | "beta": args.beta, 82 | "Vreset": args.Vreset, 83 | "Vthres": args.Vthres, 84 | }, 85 | reduction=args.reduction, 86 | track_running_stats=args.track_running_stats, 87 | mode_select=args.mode_select, 88 | mem_act=args.mem_act, 89 | TR_model=args.TR_model, 90 | attention_before_conv=args.attention_before_conv, 91 | attention_per_time=args.attention_per_time, 92 | attention_in_spike=args.attention_in_spike, 93 | ) 94 | ) 95 | self.ConvAttLIFs = nn.ModuleList(ConvAttLIFs) 96 | for l in self.ConvAttLIFs: 97 | l.attn.lam_ = args.lam 98 | 99 | # Building Head 100 | FCs = [] 101 | cfg_fc = [conv_configs[-1][1] * 8 * 8, 256, args.num_classes] 102 | cfg_bn = [False, False] 103 | for i in range(len(cfg_fc) - 1): 104 | FCs.append( 105 | AttLIF( 106 | p=dropout_ps[i + 3], 107 | attention="no", 108 | inputSize=cfg_fc[i], 109 | hiddenSize=cfg_fc[i + 1], 110 | spikeActFun=ActFun.apply, 111 | useBatchNorm=cfg_bn[i], 112 | T=args.T, 113 | pa_dict={ 114 | "alpha": args.alpha, 115 | "beta": args.beta, 116 | "Vreset": args.Vreset, 117 | "Vthres": args.Vthres, 118 | }, 119 | reduction=args.reduction, 120 | track_running_stats=args.track_running_stats, 121 | mode_select=args.mode_select, 122 | mem_act=args.mem_act, 123 | TR_model=args.TR_model, 124 | ) 125 | ) 126 | self.FCs = nn.ModuleList(FCs) 127 | 128 | def forward(self, input, mp_collect=False): 129 | firing_nums = [] 130 | if mp_collect: 131 | mps = [] 132 | b, t, _, _, _ = input.size() 133 | outputs = input.permute(1, 0, 2, 3, 4).contiguous() 134 | 135 | for body_layer in self.ConvAttLIFs: 136 | outputs = body_layer(outputs, mp_collect=mp_collect) 137 | if mp_collect: 138 | firing_nums.append(outputs[0]) 139 | mps.append(outputs[1]) 140 | outputs = outputs[0] 141 | else: 142 | firing_nums.append(outputs) 143 | 144 | outputs = outputs.reshape(t, b, -1) 145 | 146 | for head_layer in self.FCs: 147 | outputs = head_layer(outputs, mp_collect=mp_collect) 148 | if mp_collect: 149 | firing_nums.append(outputs[0]) 150 | mps.append(outputs[1]) 151 | outputs = outputs[0] 152 | else: 153 | firing_nums.append(outputs) 154 | 155 | outputs = torch.mean(outputs.permute(1, 0, 2).contiguous(), dim=1) 156 | 157 | if mp_collect: 158 | return outputs, firing_nums, mps 159 | else: 160 | return outputs, firing_nums 161 | -------------------------------------------------------------------------------- /models/Net2.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models import POOL_LAYER, ATTN_LAYER 6 | from spikingjelly.clock_driven.neuron import * 7 | from spikingjelly.clock_driven import layer, functional, surrogate 8 | 9 | # Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks https://arxiv.org/abs/2007.05785 10 | 11 | 12 | class VotingLayer(nn.Module): 13 | def __init__(self, voting_size: int = 10): 14 | super().__init__() 15 | self.voting_size = voting_size 16 | 17 | def forward(self, x: torch.Tensor): 18 | y = F.avg_pool1d(x.unsqueeze(1), self.voting_size, self.voting_size).squeeze(1) 19 | return y 20 | 21 | 22 | class ConvAttLIF(nn.Module): 23 | def __init__( 24 | self, 25 | attention, 26 | inputSize, 27 | hiddenSize, 28 | kernel_size, 29 | stride, 30 | T, 31 | pool_mode="no", 32 | track_running_stats=True, 33 | attention_before_spike=True, 34 | ): 35 | super().__init__() 36 | 37 | self.attention_before_spike = attention_before_spike 38 | 39 | self.pool = layer.SeqToANNContainer(POOL_LAYER[pool_mode]) 40 | 41 | self.conv = layer.SeqToANNContainer( 42 | nn.Conv2d( 43 | in_channels=inputSize, 44 | out_channels=hiddenSize, 45 | kernel_size=kernel_size, 46 | padding=1, 47 | stride=stride, 48 | bias=False, 49 | ) 50 | ) 51 | 52 | self.bn = layer.SeqToANNContainer( 53 | nn.BatchNorm2d(hiddenSize, track_running_stats=track_running_stats) 54 | ) 55 | 56 | self.spike = layer.MultiStepContainer( 57 | ParametricLIFNode(surrogate_function=surrogate.ATan()) 58 | ) 59 | 60 | self.attn = ATTN_LAYER[attention](T, hiddenSize) 61 | 62 | def forward(self, x): 63 | x = self.pool(x) 64 | 65 | x = self.conv(x) 66 | 67 | x = self.bn(x) 68 | 69 | if self.attention_before_spike: 70 | x = x.transpose(0, 1).contiguous() 71 | x = self.attn(x) 72 | x = x.transpose(0, 1).contiguous() 73 | 74 | x = self.spike(x) 75 | 76 | if not self.attention_before_spike: 77 | x = x.transpose(0, 1).contiguous() 78 | x = self.attn(x) 79 | x = x.transpose(0, 1).contiguous() 80 | 81 | return x 82 | 83 | 84 | class AttLIF(nn.Module): 85 | def __init__( 86 | self, 87 | in_channels, 88 | out_channels, 89 | ): 90 | super().__init__() 91 | 92 | self.fc = nn.Linear(in_channels, out_channels) 93 | self.spike = layer.MultiStepContainer( 94 | ParametricLIFNode(surrogate_function=surrogate.ATan()) 95 | ) 96 | 97 | def forward(self, x): 98 | x = functional.seq_to_ann_forward(x, self.fc) 99 | x = self.spike(x) 100 | return x 101 | 102 | 103 | class DVSGestureNet(nn.Module): 104 | def __init__( 105 | self, 106 | args, 107 | channels=128, 108 | conv_configs=[ 109 | (2, 128, 3, 1), 110 | (128, 128, 3, 1), 111 | (128, 128, 3, 1), 112 | (128, 128, 3, 1), 113 | (128, 128, 3, 1), 114 | ], 115 | pool_modes=["no", "max", "max", "max", "max"], 116 | attn_flags=[1, 1, 1, 1, 1], 117 | ): 118 | super().__init__() 119 | assert len(pool_modes) == len(conv_configs) 120 | 121 | ConvAttLIFs = [] 122 | for i, (conv_config, pool_mode, attn_flag) in enumerate( 123 | zip(conv_configs, pool_modes, attn_flags) 124 | ): 125 | ConvAttLIFs.append( 126 | nn.Sequential( 127 | ConvAttLIF( 128 | attention=args.attention, 129 | inputSize=conv_config[0], 130 | hiddenSize=conv_config[1], 131 | kernel_size=conv_config[2], 132 | stride=conv_config[3], 133 | pool_mode=pool_mode, 134 | T=args.T, 135 | track_running_stats=args.track_running_stats, 136 | ), 137 | ) 138 | ) 139 | self.pool = layer.SeqToANNContainer(POOL_LAYER["max"]) 140 | self.ConvAttLIFs = nn.ModuleList(ConvAttLIFs) 141 | 142 | FCs = [] 143 | 
cfg_fc = (channels * 4**2, 512, args.num_classes * 10) 144 | if args.dataset == "action" or args.dataset == "recogition": 145 | if args.ds <= 2: 146 | cfg_fc = (channels * 80 // (args.ds**2), 512, args.num_classes * 10) 147 | elif args.ds == 4: 148 | cfg_fc = (512, 512, args.num_classes * 10) 149 | 150 | for i in range(len(cfg_fc) - 1): 151 | FCs.append( 152 | nn.Sequential( 153 | layer.MultiStepDropout(0.0), AttLIF(cfg_fc[i], cfg_fc[i + 1]) 154 | ) 155 | ) 156 | self.FCs = nn.ModuleList(FCs) 157 | 158 | self.voting = VotingLayer(10) 159 | 160 | def forward(self, x: torch.Tensor): 161 | x = x.transpose(0, 1).contiguous() 162 | firing_nums = [] 163 | 164 | for body_layer in self.ConvAttLIFs: 165 | x = body_layer(x) 166 | firing_nums.append(x) 167 | 168 | x = self.pool(x) 169 | x = x.flatten(2) 170 | 171 | for head_layer in self.FCs: 172 | x = head_layer(x) 173 | firing_nums.append(x) 174 | 175 | x = functional.seq_to_ann_forward(x, self.voting) 176 | 177 | x = torch.mean(x, dim=0) 178 | 179 | return x, firing_nums 180 | -------------------------------------------------------------------------------- /models/Net3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models import POOL_LAYER, ATTN_LAYER 6 | from spikingjelly.clock_driven.neuron import * 7 | from spikingjelly.clock_driven import layer, functional, surrogate 8 | 9 | # Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks https://arxiv.org/abs/2007.05785 10 | 11 | 12 | class VotingLayer(nn.Module): 13 | def __init__(self, voting_size: int = 10): 14 | super().__init__() 15 | self.voting_size = voting_size 16 | 17 | def forward(self, x: torch.Tensor): 18 | y = F.avg_pool1d(x.unsqueeze(1), self.voting_size, self.voting_size).squeeze(1) 19 | return y 20 | 21 | 22 | class ConvAttLIF(nn.Module): 23 | def __init__( 24 | self, 25 | attention, 26 | inputSize, 27 | hiddenSize, 28 | kernel_size, 29 | stride, 30 | T, 31 | pool_mode="no", 32 | track_running_stats=True, 33 | attention_before_spike=True, 34 | ): 35 | super().__init__() 36 | 37 | self.attention_before_spike = attention_before_spike 38 | 39 | self.pool = layer.SeqToANNContainer(POOL_LAYER[pool_mode]) 40 | 41 | self.conv = layer.SeqToANNContainer( 42 | nn.Conv2d( 43 | in_channels=inputSize, 44 | out_channels=hiddenSize, 45 | kernel_size=kernel_size, 46 | padding=1, 47 | stride=stride, 48 | bias=False, 49 | ) 50 | ) 51 | 52 | self.bn = layer.SeqToANNContainer( 53 | nn.BatchNorm2d(hiddenSize, track_running_stats=track_running_stats) 54 | ) 55 | 56 | self.spike = layer.MultiStepContainer( 57 | LIFNode(surrogate_function=surrogate.ATan()) 58 | ) 59 | 60 | self.attn = ATTN_LAYER[attention](T, hiddenSize) 61 | 62 | def forward(self, x): 63 | x = self.pool(x) 64 | 65 | x = self.conv(x) 66 | 67 | x = self.bn(x) 68 | 69 | if self.attention_before_spike: 70 | x = x.transpose(0, 1).contiguous() 71 | x = self.attn(x) 72 | x = x.transpose(0, 1).contiguous() 73 | 74 | x = self.spike(x) 75 | 76 | if not self.attention_before_spike: 77 | x = x.transpose(0, 1).contiguous() 78 | x = self.attn(x) 79 | x = x.transpose(0, 1).contiguous() 80 | 81 | return x 82 | 83 | 84 | class AttLIF(nn.Module): 85 | def __init__( 86 | self, 87 | in_channels, 88 | out_channels, 89 | ): 90 | super().__init__() 91 | 92 | self.fc = nn.Linear(in_channels, out_channels) 93 | self.spike = layer.MultiStepContainer( 94 | LIFNode(surrogate_function=surrogate.ATan()) 
95 | ) 96 | 97 | def forward(self, x): 98 | x = functional.seq_to_ann_forward(x, self.fc) 99 | x = self.spike(x) 100 | return x 101 | 102 | 103 | class DVSGestureNet(nn.Module): 104 | def __init__( 105 | self, 106 | args, 107 | channels=128, 108 | conv_configs=[ 109 | (2, 64, 3, 1), 110 | (64, 128, 3, 1), 111 | (128, 128, 3, 1), 112 | (128, 128, 3, 1), 113 | (128, 128, 3, 1), 114 | ], 115 | pool_modes=["no", "avg", "avg", "avg", "avg"], 116 | attn_flags=[1, 1, 1, 1, 1], 117 | ): 118 | super().__init__() 119 | assert len(pool_modes) == len(conv_configs) 120 | 121 | ConvAttLIFs = [] 122 | for i, (conv_config, pool_mode, attn_flag) in enumerate( 123 | zip(conv_configs, pool_modes, attn_flags) 124 | ): 125 | ConvAttLIFs.append( 126 | nn.Sequential( 127 | ConvAttLIF( 128 | attention=args.attention, 129 | inputSize=conv_config[0], 130 | hiddenSize=conv_config[1], 131 | kernel_size=conv_config[2], 132 | stride=conv_config[3], 133 | pool_mode=pool_mode, 134 | T=args.T, 135 | track_running_stats=args.track_running_stats, 136 | ), 137 | ) 138 | ) 139 | self.pool = layer.SeqToANNContainer(POOL_LAYER["max"]) 140 | self.ConvAttLIFs = nn.ModuleList(ConvAttLIFs) 141 | 142 | FCs = [] 143 | cfg_fc = (channels * 4**2, 512, args.num_classes * 10) 144 | if args.dataset == "action" or args.dataset == "recogition": 145 | if args.ds <= 2: 146 | cfg_fc = (channels * 80 // (args.ds**2), 512, args.num_classes * 10) 147 | elif args.ds == 4: 148 | cfg_fc = (512, 512, args.num_classes * 10) 149 | 150 | for i in range(len(cfg_fc) - 1): 151 | FCs.append( 152 | nn.Sequential( 153 | layer.MultiStepDropout(0.0), AttLIF(cfg_fc[i], cfg_fc[i + 1]) 154 | ) 155 | ) 156 | self.FCs = nn.ModuleList(FCs) 157 | 158 | self.voting = VotingLayer(10) 159 | 160 | def forward(self, x: torch.Tensor): 161 | x = x.transpose(0, 1).contiguous() 162 | firing_nums = [] 163 | 164 | for body_layer in self.ConvAttLIFs: 165 | x = body_layer(x) 166 | firing_nums.append(x) 167 | 168 | x = self.pool(x) 169 | x = x.flatten(2) 170 | 171 | for head_layer in self.FCs: 172 | x = head_layer(x) 173 | firing_nums.append(x) 174 | 175 | x = functional.seq_to_ann_forward(x, self.voting) 176 | 177 | x = torch.mean(x, dim=0) 178 | 179 | return x, firing_nums 180 | -------------------------------------------------------------------------------- /models/Net4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.module.LIF import ConvAttLIF, AttLIF 5 | 6 | 7 | class DVSGestureNet(nn.Module): 8 | def __init__( 9 | self, 10 | args, 11 | conv_configs=[ 12 | ( 13 | 2, 14 | 32, 15 | 3, 16 | ), 17 | ( 18 | 32, 19 | 32, 20 | 3, 21 | ), 22 | ( 23 | 32, 24 | 32, 25 | 3, 26 | ), 27 | ], 28 | pool_kernels_size=[1, 2, 2], 29 | attn_flags=[1, 1, 1], 30 | ): 31 | super().__init__() 32 | 33 | dropout_ps = [] 34 | for i in range(len(args.ps)): 35 | if args.ps[i] == "1": 36 | dropout_ps.append(0.5) 37 | elif args.ps[i] == "0": 38 | dropout_ps.append(0.0) 39 | 40 | class ActFun(torch.autograd.Function): 41 | def __init__(self): 42 | super(ActFun, self).__init__() 43 | 44 | @staticmethod 45 | def forward(ctx, input): 46 | ctx.save_for_backward(input) 47 | return input.ge(0.0).float() 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | (input,) = ctx.saved_tensors 52 | temp = abs(input) < args.lens 53 | return grad_output * temp.float() / (2 * args.lens) 54 | 55 | ConvAttLIFs = [] 56 | for i, (conv_config, pooling_kernel_size, attn_flag) in enumerate( 57 | zip( 58 | conv_configs, 
59 | pool_kernels_size, 60 | attn_flags, 61 | ) 62 | ): 63 | if attn_flag == 0: 64 | attn = "no" 65 | elif attn_flag == 1: 66 | attn = args.attention 67 | ConvAttLIFs.append( 68 | ConvAttLIF( 69 | p=dropout_ps[i], 70 | attention=attn, 71 | inputSize=conv_config[0], 72 | hiddenSize=conv_config[1], 73 | kernel_size=conv_config[2], 74 | spikeActFun=ActFun.apply, 75 | init_method=args.init_method, 76 | useBatchNorm=True, 77 | pooling_kernel_size=pooling_kernel_size, 78 | T=args.T, 79 | pa_dict={ 80 | "alpha": args.alpha, 81 | "beta": args.beta, 82 | "Vreset": args.Vreset, 83 | "Vthres": args.Vthres, 84 | }, 85 | reduction=args.reduction, 86 | track_running_stats=args.track_running_stats, 87 | mode_select=args.mode_select, 88 | mem_act=args.mem_act, 89 | TR_model=args.TR_model, 90 | attention_before_conv=args.attention_before_conv, 91 | attention_per_time=args.attention_per_time, 92 | attention_in_spike=args.attention_in_spike, 93 | ) 94 | ) 95 | self.ConvAttLIFs = nn.ModuleList(ConvAttLIFs) 96 | for l in self.ConvAttLIFs: 97 | l.attn.lam_ = args.lam 98 | 99 | # Building Head 100 | FCs = [] 101 | cfg_fc = [conv_configs[-1][1] * 8 * 8, 256, args.num_classes] 102 | for i in range(len(cfg_fc) - 1): 103 | FCs.append( 104 | AttLIF( 105 | p=dropout_ps[i + 3], 106 | attention="no", 107 | inputSize=cfg_fc[i], 108 | hiddenSize=cfg_fc[i + 1], 109 | spikeActFun=ActFun.apply, 110 | useBatchNorm=False, 111 | T=args.T, 112 | pa_dict={ 113 | "alpha": args.alpha, 114 | "beta": args.beta, 115 | "Vreset": args.Vreset, 116 | "Vthres": args.Vthres, 117 | }, 118 | reduction=args.reduction, 119 | track_running_stats=args.track_running_stats, 120 | mode_select=args.mode_select, 121 | mem_act=args.mem_act, 122 | TR_model=args.TR_model, 123 | ) 124 | ) 125 | self.FCs = nn.ModuleList(FCs) 126 | 127 | def forward(self, input, mp_collect=False): 128 | firing_nums = [] 129 | if mp_collect: 130 | mps = [] 131 | b, t, _, _, _ = input.size() 132 | outputs = input.permute(1, 0, 2, 3, 4).contiguous() 133 | 134 | for body_layer in self.ConvAttLIFs: 135 | outputs = body_layer(outputs, mp_collect=mp_collect) 136 | if mp_collect: 137 | firing_nums.append(outputs[0]) 138 | mps.append(outputs[1]) 139 | outputs = outputs[0] 140 | else: 141 | firing_nums.append(outputs) 142 | 143 | outputs = outputs.reshape(t, b, -1) 144 | 145 | for head_layer in self.FCs: 146 | outputs = head_layer(outputs, mp_collect=mp_collect) 147 | if mp_collect: 148 | firing_nums.append(outputs[0]) 149 | mps.append(outputs[1]) 150 | outputs = outputs[0] 151 | else: 152 | firing_nums.append(outputs) 153 | 154 | outputs = torch.mean(outputs.permute(1, 0, 2).contiguous(), dim=1) 155 | 156 | if mp_collect: 157 | return outputs, firing_nums, mps 158 | else: 159 | return outputs, firing_nums 160 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.module.Attn import * 3 | 4 | POOL_LAYER = {"no": nn.Identity(), "avg": nn.AvgPool2d(2, 2), "max": nn.MaxPool2d(2, 2)} 5 | 6 | ATTN_LAYER = { 7 | "no": nn.Identity, 8 | "ASA": ASA, 9 | } 10 | -------------------------------------------------------------------------------- /models/module/Attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | 5 | 6 | class ASA(nn.Module): 7 | def __init__(self, timeWindows, channels, reduction=1, dimension=5): 8 
| super().__init__() 9 | self.avg_pool = nn.AdaptiveAvgPool3d((None, 1, 1)) 10 | self.max_pool = nn.AdaptiveMaxPool3d((None, 1, 1)) 11 | self.register_parameter('alpha', nn.Parameter(torch.FloatTensor([0.5]))) 12 | self.register_parameter('beta', nn.Parameter(torch.FloatTensor([0.5]))) 13 | self.lam_ = 0.5 14 | if reduction == 1: 15 | self.fc = nn.Linear(timeWindows, timeWindows, bias=False) 16 | else: 17 | self.fc = nn.Sequential( 18 | nn.Linear(timeWindows, timeWindows // reduction, bias=False), 19 | nn.ReLU(inplace=True), 20 | nn.Linear(timeWindows // reduction, timeWindows, bias=False) 21 | ) 22 | 23 | # --------------------- 24 | # SplitSpatialAttention 25 | # --------------------- 26 | self.conv1 = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False) 27 | self.conv2 = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False) 28 | self.active = nn.Sigmoid() 29 | 30 | def get_im_subim_channels(self, channels_im, M): 31 | _, topk = torch.topk(M, dim=2, k=channels_im) 32 | topk_ = topk.squeeze(-1).squeeze(-1) 33 | 34 | important_channels = torch.zeros_like(M.squeeze(-1).squeeze(-1)).to(M.device) 35 | 36 | for i in range(M.shape[1]): 37 | important_channels[:, i, topk_[:, i]] = 1 38 | 39 | important_channels = important_channels.unsqueeze(-1).unsqueeze(-1) 40 | return important_channels 41 | 42 | def forward(self, x): 43 | n, t, c, _, _ = x.shape 44 | tca_avg_map = self.avg_pool(x) 45 | tca_max_map = self.max_pool(x) 46 | map_add = 0.5 * (tca_avg_map + tca_max_map) + self.alpha * tca_avg_map + self.beta * tca_max_map 47 | map_add = self.fc(map_add.squeeze().transpose(1, 2).contiguous()).transpose(1, 2).contiguous() 48 | 49 | # --------------------- 50 | # SplitSpatialAttention 51 | # --------------------- 52 | # map_add shape: N, T, C, 1, 1 53 | important_channels = self.get_im_subim_channels(int(c * self.lam_), map_add) 54 | important_times = self.get_im_subim_channels(int(t * self.lam_), map_add.transpose(1, 2).contiguous()) 55 | important_times = important_times.transpose(1, 2).contiguous() 56 | 57 | important_tc = (important_channels + important_times) / 2 58 | important_tc = torch.where(important_tc == 0.5, 1, 0) 59 | subimportant_tc = 1. 
- important_tc 60 | 61 | important_features = important_tc * x 62 | subimportant_features = subimportant_tc * x 63 | 64 | important_features = rearrange(important_features, 'n t c h w -> n (t c) h w') 65 | subimportant_features = rearrange(subimportant_features, 'n t c h w -> n (t c) h w') 66 | 67 | im_AvgPool = torch.mean(important_features, dim=1, keepdim=True) / self.lam_ 68 | im_MaxPool, _ = torch.max(important_features, dim=1, keepdim=True) 69 | 70 | subim_AvgPool = torch.mean(subimportant_features, dim=1, keepdim=True) / (1 - self.lam_) 71 | subim_MaxPool, _ = torch.max(subimportant_features, dim=1, keepdim=True) 72 | 73 | im_x = torch.cat([im_AvgPool, im_MaxPool], dim=1) 74 | subim_x = torch.cat([subim_AvgPool, subim_MaxPool], dim=1) 75 | 76 | im_map = self.active(self.conv1(im_x)) 77 | subim_map = self.active(self.conv2(subim_x)) 78 | 79 | important_features = im_map * important_features 80 | subimportant_features = subim_map * subimportant_features 81 | 82 | htsa_out = important_features + subimportant_features 83 | htsa_out = rearrange(htsa_out, 'n (t c) h w -> n t c h w', t=t) 84 | 85 | return htsa_out 86 | -------------------------------------------------------------------------------- /models/module/LIF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from spikingjelly.clock_driven import layer 5 | 6 | from models import ATTN_LAYER 7 | from utils import paramInit 8 | 9 | 10 | class IFCell(nn.Module): 11 | def __init__( 12 | self, 13 | inputSize, 14 | hiddenSize, 15 | spikeActFun, 16 | pa_dict=None, 17 | pa_train_self=False, 18 | p=0, 19 | mode_select='spike', 20 | mem_act=torch.relu, 21 | TR_model='NTR', 22 | attention_flag='no', 23 | has_conv=False, 24 | attention_in_spike=False, 25 | ): 26 | super().__init__() 27 | self.inputSize = inputSize 28 | self.hiddenSize = hiddenSize 29 | self.spikeActFun = spikeActFun 30 | 31 | self.UseDropOut = True 32 | self.batchSize = None 33 | self.pa_dict = pa_dict 34 | self.pa_train_self = pa_train_self 35 | self.p = p 36 | self.has_conv = has_conv 37 | 38 | self.attention_flag = attention_flag 39 | self.attention_in_spike = attention_in_spike 40 | self.attn = ATTN_LAYER[attention_flag](1, hiddenSize, reduction=1) 41 | 42 | # LIAF 43 | self.TR_model = TR_model 44 | # LIF 45 | self.mode_select = mode_select 46 | self.mem_act = mem_act 47 | 48 | if not pa_train_self and pa_dict is None: 49 | pa_dict = {'alpha': 0.3, 'beta': 0., 'Vreset': 0., 'Vthres': 0.6} 50 | self.pa_dict = pa_dict 51 | 52 | if self.pa_train_self: 53 | self.alpha = nn.Parameter(torch.Tensor(1, hiddenSize)) 54 | self.beta = nn.Parameter(torch.Tensor(1, hiddenSize)) 55 | self.Vreset = nn.Parameter(torch.Tensor(1, hiddenSize)) 56 | self.Vthres = nn.Parameter(torch.Tensor(1, hiddenSize)) 57 | nn.init.uniform_(self.alpha, 0, 1) 58 | nn.init.uniform_(self.beta, 0, 1) 59 | nn.init.uniform_(self.Vreset, 0, 1) 60 | nn.init.uniform_(self.Vthres, 0, 1) 61 | else: 62 | self.alpha = self.pa_dict['alpha'] 63 | self.beta = self.pa_dict['beta'] 64 | self.Vreset = self.pa_dict['Vreset'] 65 | self.Vthres = self.pa_dict['Vthres'] 66 | 67 | if 0 < self.p < 1: 68 | self.dropout = layer.Dropout(p=self.p) 69 | 70 | def _forward_impl(self, input): 71 | input = input.reshape([self.batchSize, -1]) 72 | return input 73 | 74 | def _forward_impl_attn(self, input): 75 | input = self.attn(input.unsqueeze(1)) 76 | input = input.squeeze(1) 77 | return input 78 | 79 | def forward(self, input, 
init_v=None): 80 | self.batchSize = input.size()[0] 81 | 82 | if not self.has_conv: 83 | input = self._forward_impl(input) 84 | 85 | if not hasattr(self, "h"): 86 | if init_v is None: 87 | if self.has_conv: 88 | self.h = torch.zeros( 89 | self.batchSize, 90 | self.hiddenSize, 91 | input.size()[-2], 92 | input.size()[-1], 93 | device=input.device 94 | ) 95 | else: 96 | self.h = torch.zeros( 97 | self.batchSize, 98 | self.hiddenSize, 99 | device=input.device 100 | ) 101 | else: 102 | self.h = init_v.clone() 103 | 104 | if input.device != self.h.device: 105 | input = input.to(self.h.device) 106 | 107 | # Step 1: accumulate and reset,spike used as forgetting gate 108 | u = self.h + input 109 | 110 | if self.has_conv and self.attention_in_spike: 111 | u = self._forward_impl_attn(u) 112 | 113 | x_ = u - self.Vthres 114 | x = self.spikeActFun(x_) 115 | 116 | self.h = x * self.Vreset + (1 - x) * u 117 | # self.h = x * self.Vthres + (1 - x) * u 118 | self.h = self.h * self.alpha + self.beta 119 | 120 | # step 4: 121 | if self.mode_select == 'spike': 122 | x = x 123 | elif self.mode_select == 'mem': 124 | # TR 125 | if self.TR_model == 'TR': 126 | if not self.mem_act: 127 | x = x_ 128 | else: 129 | x = self.mem_act(x_) 130 | else: 131 | if not self.mem_act: 132 | x = u 133 | else: 134 | x = self.mem_act(u) 135 | 136 | if 1 > self.p > 0: 137 | x = self.dropout(x) 138 | return x 139 | 140 | def reset(self): 141 | if hasattr(self, "h"): 142 | del self.h 143 | 144 | 145 | class AttLIF(nn.Module): 146 | def __init__( 147 | self, 148 | inputSize, 149 | hiddenSize, 150 | spikeActFun, 151 | attention='T', 152 | onlyLast=False, 153 | useBatchNorm=False, 154 | init_method=None, 155 | pa_dict=None, 156 | pa_train_self=False, 157 | bias=True, 158 | reduction=1, 159 | T=1, 160 | p=0, 161 | track_running_stats=False, 162 | mode_select='spike', 163 | mem_act=torch.relu, 164 | TR_model='NTR', 165 | ): 166 | super().__init__() 167 | self.onlyLast = onlyLast 168 | self.useBatchNorm = useBatchNorm 169 | self.network = nn.Sequential() 170 | 171 | self.attention_flag = attention 172 | self.linear = layer.SeqToANNContainer(nn.Linear( 173 | in_features=inputSize, 174 | out_features=hiddenSize, 175 | bias=bias, 176 | )) 177 | if init_method is not None: 178 | paramInit(model=self.linear, method=init_method) 179 | 180 | if self.useBatchNorm: 181 | self.bn = layer.SeqToANNContainer(nn.BatchNorm1d( 182 | num_features=hiddenSize, 183 | track_running_stats=track_running_stats 184 | )) 185 | 186 | assert reduction <= hiddenSize, \ 187 | "the attn_channel should be bigger than the reduction" 188 | self.attn = ATTN_LAYER[attention](T, hiddenSize, reduction=reduction) 189 | 190 | self.dropout = layer.Dropout(p) 191 | 192 | self.spike = layer.MultiStepContainer(IFCell( 193 | inputSize, 194 | hiddenSize, 195 | spikeActFun, 196 | pa_dict=pa_dict, 197 | pa_train_self=pa_train_self, 198 | p=p, 199 | mode_select=mode_select, 200 | mem_act=mem_act, 201 | TR_model=TR_model, 202 | )) 203 | 204 | def forward(self, data, mp_collect=False): 205 | data = self.dropout(data) 206 | output = self.linear(data) 207 | if self.useBatchNorm: 208 | output = self.bn(output) 209 | if self.onlyLast: 210 | return output 211 | 212 | mp = output 213 | mp = self.attn(mp.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous() 214 | output = self.spike(mp) 215 | # print(np.sum(mp.cpu().detach().numpy()>0.3)) 216 | # print(np.sum(output.cpu().detach().numpy()>0.3)) 217 | # print() 218 | 219 | if mp_collect: 220 | return output, mp 221 | else: 222 | return 
output 223 | 224 | 225 | class ConvAttLIF(nn.Module): 226 | def __init__( 227 | self, 228 | inputSize, 229 | hiddenSize, 230 | kernel_size, 231 | spikeActFun, 232 | attention='no', 233 | onlyLast=False, 234 | padding=1, 235 | useBatchNorm=False, 236 | init_method=None, 237 | pa_dict=None, 238 | pa_train_self=False, 239 | reduction=1, 240 | T=1, 241 | stride=1, 242 | pooling_kernel_size=1, 243 | p=0, 244 | track_running_stats=False, 245 | mode_select='spike', 246 | mem_act=torch.relu, 247 | TR_model='NTR', 248 | attention_before_conv=False, 249 | attention_per_time=False, 250 | attention_in_spike=False, 251 | ): 252 | super().__init__() 253 | 254 | self.onlyLast = onlyLast 255 | self.attention_flag = attention 256 | self.attention_before_conv = attention_before_conv 257 | self.attention_per_time = attention_per_time 258 | 259 | self.conv = layer.SeqToANNContainer(nn.Conv2d( 260 | in_channels=inputSize, 261 | out_channels=hiddenSize, 262 | kernel_size=kernel_size, 263 | padding=padding, 264 | stride=stride, 265 | )) 266 | 267 | if init_method is not None: 268 | paramInit(model=self.conv, method=init_method) 269 | 270 | self.useBatchNorm = useBatchNorm 271 | 272 | if self.useBatchNorm: 273 | self.bn = layer.SeqToANNContainer(nn.BatchNorm2d( 274 | hiddenSize, track_running_stats=track_running_stats 275 | )) 276 | 277 | self.pooling_kernel_size = pooling_kernel_size 278 | if self.pooling_kernel_size > 1: 279 | self.pooling = layer.SeqToANNContainer( 280 | nn.AvgPool2d(kernel_size=pooling_kernel_size) 281 | ) 282 | 283 | if attention_before_conv: 284 | # assert attention != "HCSA" 285 | attn_channels = inputSize 286 | else: 287 | attn_channels = hiddenSize 288 | 289 | if attention_per_time: 290 | assert "T" not in attention 291 | T = 1 292 | 293 | assert reduction <= attn_channels, \ 294 | "the attn_channel should be bigger than the reduction" 295 | self.attn = ATTN_LAYER[attention](T, attn_channels, reduction, 5) 296 | 297 | self.dropout = layer.Dropout(p) 298 | 299 | self.spike = layer.MultiStepContainer(IFCell( 300 | inputSize=inputSize, 301 | hiddenSize=hiddenSize, 302 | spikeActFun=spikeActFun, 303 | pa_dict=pa_dict, 304 | pa_train_self=pa_train_self, 305 | p=p, 306 | mode_select=mode_select, 307 | mem_act=mem_act, 308 | TR_model=TR_model, 309 | attention_flag=self.attention_flag, 310 | has_conv=True, 311 | attention_in_spike=attention_in_spike, 312 | )) 313 | 314 | def _forward_impl_attn(self, data): 315 | data = data.permute(1, 0, 2, 3, 4).contiguous() 316 | 317 | if self.attention_per_time: 318 | for step in range(data.size(1)): 319 | out = data[:, step, :, :, :] 320 | out = self.attn(out.unsqueeze(1)) 321 | output = out.squeeze(1) 322 | 323 | if step == 0: 324 | temp = list(output.size()) 325 | temp.insert(1, list(data.size())[1]) 326 | outputsum = torch.zeros(temp) 327 | if outputsum.device != data.device: 328 | outputsum = outputsum.to(data.device) 329 | 330 | outputsum[:, step, :, :, :] = output 331 | else: 332 | outputsum = self.attn(data) 333 | 334 | data = outputsum.permute(1, 0, 2, 3, 4).contiguous() 335 | return data 336 | 337 | def forward(self, data, mp_collect=False): 338 | data = self.dropout(data) 339 | if self.attention_before_conv: 340 | data = self._forward_impl_attn(data) 341 | 342 | output = self.conv(data) 343 | if self.useBatchNorm: 344 | output = self.bn(output) 345 | 346 | if self.pooling_kernel_size > 1: 347 | output = self.pooling(output) 348 | 349 | if self.onlyLast: 350 | return output 351 | 352 | mp = output 353 | if not self.attention_before_conv: 354 | mp = 
self._forward_impl_attn(mp) 355 | output = self.spike(mp) 356 | 357 | if mp_collect: 358 | return output, mp 359 | else: 360 | return output 361 | -------------------------------------------------------------------------------- /models/module/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | torch 3 | torchvision 4 | torchinfo 5 | spikingjelly 6 | h5py 7 | pandas 8 | seaborn 9 | -------------------------------------------------------------------------------- /train_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torchinfo import summary 4 | 5 | from model import create_net 6 | from dataset import create_dataset 7 | from config import parser 8 | from apis import run 9 | from utils import set_seed, save_csv 10 | 11 | 12 | def main(): 13 | args = parser.parse_args() 14 | 15 | set_seed(args.seed) 16 | 17 | args.recordPath = args.modelPath 18 | args.im_width, args.im_height = (128 // args.ds, 128 // args.ds) 19 | if args.dataset == "action" or args.dataset == "recogition": 20 | args.im_width, args.im_height = (math.ceil(346 / args.ds), 260 // args.ds) 21 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | args.device_ids = range(torch.cuda.device_count()) 23 | args.modelPath = args.modelPath + args.attention 24 | args.name = ( 25 | args.dataset 26 | + "_dt=" 27 | + str(args.dt) 28 | + "ms" 29 | + "_T=" 30 | + str(args.T) 31 | + "_attn=" 32 | + args.attention 33 | + "_reduc=" 34 | + str(args.reduction) 35 | + "_lam=" 36 | + str(args.lam) 37 | + "_seed=" 38 | + str(args.seed) 39 | + "_arch=" 40 | + str(args.arch) 41 | ) 42 | args.modelNames = args.name + ".pth" 43 | args.recordNames = args.name + ".csv" 44 | 45 | print(args.name) 46 | 47 | create_dataset(args=args) 48 | 49 | create_net(args=args) 50 | 51 | summary( 52 | args.model, 53 | (2, args.T, args.in_channels, args.im_height, args.im_width), 54 | depth=3, 55 | ) 56 | 57 | run(args=args) 58 | 59 | print("best acc:", args.best_acc, "best_epoch:", args.best_epoch) 60 | 61 | save_csv(args=args) 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=100): 11 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs. 
12 | :param lr_decay_epoch: decay the learning rate every this many epochs 13 | :param init_lr: initial learning rate (not used inside this function) 14 | :param epoch: current training epoch 15 | :param optimizer: the optimizer whose learning rate is decayed in place 16 | """ 17 | 18 | if epoch % lr_decay_epoch == 0 and epoch > 1: 19 | for param_group in optimizer.param_groups: 20 | param_group["lr"] = param_group["lr"] * 0.1 21 | 22 | return optimizer 23 | 24 | 25 | # parameter initialization helper, used by the LIF layers when init_method is set 26 | def paramInit(model, method="xavier"): 27 | scale = 0.05 28 | for name, w in model.named_parameters(): 29 | if "weight" in name: 30 | if method == "xavier": 31 | nn.init.xavier_normal_(w) 32 | elif method == "kaiming": 33 | nn.init.kaiming_normal_(w) 34 | else: 35 | nn.init.normal_(w) 36 | w *= scale 37 | elif "bias" in name: 38 | nn.init.constant_(w, 0) 39 | else: 40 | pass 41 | 42 | 43 | def set_seed(seed): 44 | # for reproducibility. 45 | # note that pytorch is not completely reproducible 46 | # https://pytorch.org/docs/stable/notes/randomness.html 47 | torch.backends.cudnn.enabled = True 48 | torch.backends.cudnn.benchmark = True 49 | # torch.backends.cudnn.deterministic = True 50 | random.seed(seed) 51 | os.environ["PYTHONHASHSEED"] = str(seed) 52 | np.random.seed(seed) 53 | torch.initial_seed() # dataloader multi processing 54 | torch.manual_seed(seed) 55 | torch.cuda.manual_seed(seed) 56 | torch.cuda.manual_seed_all(seed) 57 | 58 | 59 | def str2bool(v): 60 | if v == "True": 61 | return True 62 | else: 63 | return False 64 | 65 | 66 | def global_avgpool2d(x): 67 | batch_size = x.shape[0] 68 | channel_size = x.shape[1] 69 | return x.reshape(batch_size, channel_size, -1).mean(dim=2) 70 | 71 | 72 | def winner_take_all(x, sparsity_ratio): 73 | k = math.ceil(sparsity_ratio * x.shape[1]) # number of winners kept along dim 1 74 | winner_idx = x.topk(k, 1)[1] 75 | winner_mask = torch.zeros_like(x) 76 | winner_mask.scatter_(1, winner_idx, 1) 77 | x = x * winner_mask 78 | 79 | return x, winner_mask 80 | 81 | 82 | def save_csv(args): 83 | args.epoch_list.append(args.best_epoch) 84 | args.acc_test_list.append(args.best_acc) 85 | 86 | lists = [ 87 | args.loss_train_list, 88 | args.loss_test_list, 89 | args.acc_train_list, 90 | args.acc_test_list, 91 | ] 92 | csv = pd.DataFrame( 93 | data=lists, 94 | index=["Train_Loss", "Test_Loss", "Train_Accuracy", "Test_Accuracy"], 95 | columns=args.epoch_list, 96 | ) 97 | csv.index.name = "Epochs" 98 | 99 | if not os.path.exists(args.modelPath): 100 | os.makedirs(args.modelPath) 101 | csv.to_csv(args.modelPath + os.sep + args.recordNames) 102 | --------------------------------------------------------------------------------
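For orientation, below is a minimal, hedged usage sketch of the spiking layers defined in models/module/LIF.py. The full architectures that assemble these layers live in models/Net.py and its variants; the hard-threshold spike function, the attention keys 'no' and 'T', and every hyper-parameter in this sketch are illustrative assumptions based only on the defaults visible in LIF.py, not values taken from the repository's configs.

import torch

from models.module.LIF import AttLIF, ConvAttLIF


def spike_fn(x):
    # Hard-threshold stand-in for the repository's surrogate-gradient spike
    # activation (the real spikeActFun is defined elsewhere in the repo).
    return (x >= 0.).float()


T, N, C, H, W = 8, 2, 2, 32, 32  # time steps, batch, event channels, height, width

# Convolutional spiking layer: Conv2d -> BatchNorm -> attention -> IFCell, time-major.
conv_lif = ConvAttLIF(
    inputSize=C,
    hiddenSize=16,
    kernel_size=3,
    spikeActFun=spike_fn,
    attention='no',          # assumed identity-like entry of ATTN_LAYER (the default here)
    useBatchNorm=True,
    pooling_kernel_size=2,   # AvgPool2d(2) halves the spatial size
    T=T,
    p=0.2,
)

# Fully connected spiking layer on the flattened convolutional spike maps.
fc_lif = AttLIF(
    inputSize=16 * (H // 2) * (W // 2),
    hiddenSize=10,
    spikeActFun=spike_fn,
    attention='T',           # assumed temporal-attention key (the default here)
    T=T,
    p=0.2,
)

x = torch.rand(T, N, C, H, W)   # both layers expect time-major input: (T, N, ...)
s = conv_lif(x)                 # -> (T, N, 16, H // 2, W // 2) spike trains
s = s.flatten(start_dim=2)      # -> (T, N, 16 * (H // 2) * (W // 2))
out = fc_lif(s)                 # -> (T, N, 10)
logits = out.mean(dim=0)        # simple rate-code readout over the T steps

# IFCell caches its membrane potential in self.h between calls, so clear all
# stateful sub-modules before feeding the next, unrelated sequence.
for m in list(conv_lif.modules()) + list(fc_lif.modules()):
    if hasattr(m, 'reset'):
        m.reset()

Each IFCell step computes u = h + x, s = spikeActFun(u - Vthres) and then h <- alpha * (s * Vreset + (1 - s) * u) + beta, i.e. a soft reset where a spike fired and a leaky decay elsewhere. Because that state h is cached on the module, it has to be cleared between unrelated sequences, either manually as in the loop above or with spikingjelly's functional.reset_net, which performs the same sweep over modules that expose reset().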