├── src
│   ├── config
│   │   ├── __init__.py
│   │   ├── config_utils.py
│   │   ├── room_setting.ini
│   │   ├── forest_setting.ini
│   │   └── corridor_setting.ini
│   ├── model
│   │   ├── snn.py
│   │   ├── cann.py
│   │   └── mhnn.py
│   ├── tools
│   │   └── utils.py
│   └── main
│       └── main_mhnn.py
├── LICENSE
└── README.md
/src/config/__init__.py:
--------------------------------------------------------------------------------
import os

# Source root (the parent of this config package), used to resolve
# configuration-file paths given relative to src/.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
--------------------------------------------------------------------------------
/src/config/config_utils.py:
--------------------------------------------------------------------------------
import configparser
from os.path import isfile, join
from config import ROOT_DIR

def get_config(config_file, logger):
    config = configparser.ConfigParser()
    config_path = join(ROOT_DIR, config_file)
    logger.info(f'The path of the configuration file is {config_path}')
    assert isfile(config_path), f'Configuration file not found: {config_path}'
    config.read(config_path)
    return config

def update_config(config, updates):
    # Command-line options take precedence over values from the .ini file.
    for k in updates:
        config[k] = updates[k]
    return config

def get_bool_from_config(config, key, default_value=False):
    # configparser stores every value as a string, so booleans are compared literally.
    if key in config:
        return config[key] == 'True'
    return default_value
--------------------------------------------------------------------------------
/src/config/room_setting.ini:
--------------------------------------------------------------------------------
[main]
cnn_arch = mobilenet_v2
num_class = 100
rnn_num = 128
cann_num = 128
lr = 2.5e-4
reservoir_num = 4000
sparse_lambdas = 5
r = 0.15
batch_size = 20
spiking_threshold = 0.5

num_epoch = 30
num_iter = 100

train_exp_idx = 1
test_exp_idx = 2

ann_pre_load = True
snn_pre_load = False
re_trained = True

w_fps = 1
w_gps = 1
w_dvs = 1
w_head = 1
w_time = 1

dvs_seq = 3
seq_len_aps = 3
seq_len_gps = 3
seq_len_dvs = 3
seq_len_head = 3
seq_len_time = 3
dvs_expand = 3

data_path = /data/NeuroVPR_Datasets/room/room_v
snn_path = ./pretrained_model/snn_room_model.t7
hnn_path = ./pretrained_model/mhnn_room_model.t7
model_saving_file_name = mhnn_room_v1
device_id = 7
--------------------------------------------------------------------------------
/src/config/forest_setting.ini:
--------------------------------------------------------------------------------
[main]
cnn_arch = mobilenet_v2
num_class = 100
rnn_num = 128
cann_num = 128
lr = 2.5e-4
reservoir_num = 4000
sparse_lambdas = 5
r = 0.15
batch_size = 20
spiking_threshold = 0.5

num_epoch = 30
num_iter = 100

train_exp_idx = 1
test_exp_idx = 2

ann_pre_load = True
snn_pre_load = False
re_trained = True

w_fps = 1
w_gps = 1
w_dvs = 1
w_head = 1
w_time = 1

dvs_seq = 3
seq_len_aps = 3
seq_len_gps = 3
seq_len_dvs = 3
seq_len_time = 3
seq_len_head = 3
dvs_expand = 3

data_path = /data/NeuroVPR_Datasets/campus/forest/forest_v
snn_path = ./pretrained_model/snn_forest_model.t7
hnn_path = ./pretrained_model/mhnn_forest_model.t7
model_saving_file_name = mhnn_forest_v1
device_id = 7
--------------------------------------------------------------------------------
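For reference, the helpers above are consumed roughly as follows. This is an illustrative sketch, not repository code; note that `get_config` resolves paths against `ROOT_DIR`, i.e. relative to `src/`:

```python
import logging
from src.config.config_utils import get_config, update_config, get_bool_from_config

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

config = get_config('config/room_setting.ini', logger=log)['main']  # read the [main] section
config = update_config(config, {'device_id': '0'})                  # override an option in place

batch_size = int(config['batch_size'])              # configparser values are strings
ann_pre_load = get_bool_from_config(config, 'ann_pre_load')
```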
/src/config/corridor_setting.ini:
--------------------------------------------------------------------------------
[main]
cnn_arch = mobilenet_v2
num_class = 100
rnn_num = 128
cann_num = 128
lr = 2.5e-4

reservoir_num = 4000
spiking_threshold = 0.5
sparse_lambdas = 5
r = 0.15

batch_size = 20
num_epoch = 30
num_iter = 100

train_exp_idx = 5,7
test_exp_idx = 9

ann_pre_load = True
snn_pre_load = False
re_trained = True

w_fps = 1
w_gps = 1
w_dvs = 1
w_head = 1
w_time = 1

dvs_seq = 3
seq_len_aps = 3
seq_len_gps = 3
seq_len_dvs = 3
seq_len_time = 3
seq_len_head = 3
dvs_expand = 3

data_path = /data/NeuroVPR_Datasets/corridor/floor3/floor3_v
snn_path = ./pretrained_snn_model/snn_corridor_model.t7
hnn_path = ./pretrained_snn_model/hnn_corridor_model.t7
model_saving_file_name = mhnn_corridor_v1
device_id = 7
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 The authors of the paper 'Brain-inspired multimodal hybrid neural network for robot place recognition'

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Brain-inspired multimodal hybrid neural network for robot place recognition

## Table of Contents
1. About
2. Structure
3. Usage
4. Requirements
5. Datasets
6. Citation
7. License
8. Acknowledgments
9. Note

## About
This library implements a brain-inspired multimodal hybrid neural network (MHNN) for robot place recognition. The MHNN encodes and integrates multimodal cues from both conventional and neuromorphic sensors. Specifically, to encode different sensory cues, we build various neural networks of spatial view cells, place cells, head direction cells, and time cells. To integrate these cues, we design a multiscale liquid state machine that can process and fuse multimodal information effectively and asynchronously by using diverse neuronal dynamics and bio-inspired inhibitory circuits.
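To make the integration step concrete, the following minimal sketch (illustrative only; the dimensions, names, and winner ratio are placeholders, not the exact implementation) shows the core idea of the multiscale liquid state machine: each modality drives its own block of a shared reservoir through a block-diagonal projection, and an inhibition-like winner-take-all rule lets only the most active neurons in each block fire:

```python
import torch

batch, num_block, block_size = 4, 5, 800        # one reservoir block per modality
dims = (1280, 512, 256, 128, 128)               # per-modality feature sizes (placeholders)
feats = [torch.rand(batch, d) for d in dims]    # APS, DVS, position, time, head direction

# Block-diagonal projection: modality i is wired only to reservoir block i.
proj = torch.zeros(sum(dims), num_block * block_size)
row = 0
for i, d in enumerate(dims):
    proj[row:row + d, i * block_size:(i + 1) * block_size] = torch.rand(d, block_size)
    row += d

mem = torch.cat(feats, dim=1) @ proj            # membrane drive per reservoir neuron
mem = mem.view(batch, num_block, block_size)
top = torch.quantile(mem, 0.8, dim=2, keepdim=True)   # per-block activity threshold
spikes = (mem > top).float()                    # only block-wise winners fire
```

In the actual model (src/model/mhnn.py below), the same masking idea appears as `project_mask_matrix`, and the winner threshold is computed with `torch.quantile` inside `wta_mem_update`.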

## Structure
> * src
>> * model: contains the MHNN model
>> * tools: contains the utility functions
>> * config: contains the configuration files
>> * main: contains the demo

## Usage
* Step 1, set up the running environment
* Step 2, download the datasets
* Step 3, download the code:

  git clone https://github.com/cognav/neurogpr.git
* Step 4, run the demo in the ‘main’ folder:
  python main_mhnn.py --config_file ../config/corridor_setting.ini

## Requirements
The following libraries are needed to run the demo.
* Python 3.7.4
* PyTorch 1.11.0
* Torchvision 0.12.0
* NumPy 1.21.5
* SciPy 1.7.1

## Datasets
The datasets include four groups of data collected in different environments.
* The Room, Corridor, and THU-Forest datasets are available on Zenodo (https://zenodo.org/record/7827108#.ZD_ke3bP0ds).
* The public Brisbane-Event-VPR dataset is available on Zenodo (https://zenodo.org/record/4302805#.ZD8puXbP0ds).

## Citation
Fangwen Yu, Yujie Wu, Songchen Ma, Mingkun Xu, Hongyi Li, Huanyu Qu, Chenhang Song, Taoyi Wang, Rong Zhao and Luping Shi. Brain-inspired multimodal hybrid neural network for robot place recognition. Science Robotics (accepted).

## License
MIT License

## Acknowledgments
If you have any questions, please contact us.

## Note
If the parameter settings are different, the results may not be consistent. You may need to modify the settings or adjust the code accordingly.
--------------------------------------------------------------------------------
/src/model/snn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

# Global spiking-neuron hyperparameters.
thresh = 0.15   # firing threshold
lens = 0.5      # half-width of the rectangular surrogate-gradient window
probs = 0.5
decay = 0.6     # membrane-potential decay factor

# (in_channels, out_channels, stride, padding, kernel_size) for each conv layer.
cfg_cnn = [(2, 64, 3, 1, 3),
           (64, 128, 2, 1, 3),
           (128, 256, 2, 1, 3),
           (256, 256, 1, 1, 3),
           ]
cfg_fc = [512, 512, 200]

def mem_update(opts, x, mem, spike):
    # Leaky integrate-and-fire update: decay the membrane, reset where a spike
    # occurred, integrate the new input, then emit spikes above threshold.
    mem = mem * decay * (1 - spike) + opts(x)
    spike = act_fun(mem - thresh)
    return mem, spike

class ActFun(torch.autograd.Function):
    # Heaviside spiking nonlinearity with a rectangular surrogate gradient.

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(0).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input) < lens
        return grad_input * temp.float()

act_fun = ActFun.apply

# (The remainder of snn.py, including the SNN class used by mhnn.py, is
# truncated in this listing.)
--------------------------------------------------------------------------------
/src/model/cann.py:
--------------------------------------------------------------------------------
# (The beginning of cann.py, including the CANN class instantiated by mhnn.py,
# is truncated in this listing; the surviving fragment of the commented-out
# CANN1D variant resumes below.)

#             if d > np.pi:  # set up connectivity matrix on a ring
#                 d = 2*np.pi - d
#             self.w[i, j] = self.J0 * np.exp(-0.5 * d) / (np.sqrt(2 * np.pi) * self.a)
#
#     def data_reverse_transform(self, x):
#         x_min, x_max = np.min(x, axis=0), np.max(x, axis=0)
#         x_range = x_max - x_min
#         return (x - x_min) / (1e-11 + x_range) * self.z_range + self.z_min
#
#     def get_stimulus_by_pos(self, x):
#         d = (x - self.centers) / (self.z_range + 1e-4)
#         y = np.exp(-1 * np.square(d / 0.5))
#         return y
#
#     def update(self, data, trajactory_mode=False):
#         seq_len, seq_dim = data.shape
#         u = np.zeros((self.num, seq_dim))
#         u_record = np.zeros((seq_len, self.num, seq_dim))
#
#         for i in range(0, seq_len):
#             if i == 0:
#                 cur_stimulus = self.get_stimulus_by_pos(data[0])
#             else:
#                 cur_stimulus = self.get_stimulus_by_pos(data[i - 1])
#             t = np.arange(self.I_dur * i, self.I_dur * i + self.I_dur, self.dt)
#             cur_du = odeint(int_u, u.flatten(), t,
#                             args=(self.w, cur_stimulus.t().numpy(), self.tau, self.Inh_inp, self.k, self.dt))
#
#             u = cur_du[-1].reshape(-1, seq_dim)
#             r1 = np.square(u)
#             r2 = 1.0 + 0.5 * self.k * np.sum(r1, axis=0)
#             u_record[i] = r1 / r2.reshape(1, -1)
#
#         out = r1 / r2
#         if trajactory_mode:
#             out = u_record
#
#         return out
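# Interface note (inferred from mhnn.py, which builds the attractor via
# MHNN.cann_init(data) and then calls self.cann.update(seq, trajactory_mode=True)
# per sample): the uncommented CANN class truncated above takes an (N, 4) array
# of [x, y, z, heading] samples at construction time, and update() returns the
# population activity over a trajectory segment with shape
# (seq_len, cann_num, 4), which mhnn.py flattens into 4 * cann_num features per
# time step.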
# class CANN2D():
#     def __init__(self, data, length=128, tau=0.02, Inh_inp=-.1, I_dur=0.4, dt=0.05, A=20, k=1):
#         self.length = length      # 2D-CANN: (length, length)
#         self.tau = tau            # key parameter: controls the prediction curve
#         self.Inh_inp = Inh_inp    # key parameter: inhibitory stimulus
#         self.I_dur = I_dur
#         self.dt = dt
#         self.A = A
#         self.k = k
#         self.t_wins = data.shape[0] + 1
#         self.n_units = data.shape[1]
#         self.z_min, self.z_max = np.min(data), np.max(data) + 1e-7
#         self.z_range = self.z_max - self.z_min
#         self.x = np.linspace(self.z_min, self.z_max, length)  # sample centers
#         self.dx = self.z_range / length
#
#         self.a = 2
#         self.J0 = 4
#         self.u = []  # membrane potentials
#         self.u.append(np.zeros((length, length)))
#         self.r = []  # firing rates
#         self.r.append(np.zeros((length, length)))
#         self.conn_mat = self.make_conn()
#
#     def dist(self, d):
#         v_size = np.asarray([self.z_range, self.z_range])
#         return np.where(d > v_size / 2, v_size - d, d)
#
#     def make_conn(self):
#         x1, x2 = np.meshgrid(self.x / self.z_max, self.x / self.z_max)
#         value = np.stack([x1.flatten(), x2.flatten()]).T
#
#         dd = self.dist(np.abs(value[0] - value))
#         d = np.linalg.norm(dd, ord=1, axis=1)
#         d = d.reshape((self.length, self.length))
#         Jxx = self.J0 * np.exp(-0.5 * np.square(d / self.a)) / (np.sqrt(2 * np.pi) * self.a)
#         return Jxx
#
#     def data_reverse_transform(self, x):
#         x_min, x_max = np.min(x, axis=0), np.max(x, axis=0)
#         x_range = x_max - x_min
#         return (x - x_min) / (1e-11 + x_range) * self.z_range + self.z_min
#
#     def get_stimulus_by_pos(self, x):
#         # Encode the external stimulus with Gaussian coding.
#         x1, x2 = np.meshgrid(self.x, self.x)
#         value = np.stack([x1.flatten(), x2.flatten()]).T
#         # dd = self.dist(np.abs(np.asarray(x) - value))
#         dd = np.abs(np.asarray(x) - value)
#         d = np.linalg.norm(dd, axis=1)
#         d = d.reshape((self.length, self.length))
#         return self.A * np.exp(-0.25 * np.square(d / 2))
#
#     def update(self, data, trajactory_mode=False):
#         # Convert the data input into a 3D tensor: seq_len * dim * dim.
#         seq_len, seq_dim = data.shape
#         u = np.zeros((self.length, self.length))
#         u_record = np.zeros((seq_len, self.length, self.length))
#
#         for i in range(1, seq_len):
#             # Generate the 2D inputs.
#             cur_stimulus = self.get_stimulus_by_pos(data[i - 1])
#             t = np.arange(self.I_dur * i, self.I_dur * i + self.I_dur, self.dt)
#             # Perform the updates.
#             cur_du = odeint(int_u, u.flatten(), t,
#                             args=(self.conn_mat, cur_stimulus.T,
#                                   self.tau, self.Inh_inp, self.k, self.dt))
#             u = cur_du[-1].reshape(-1, self.length)
#             r1 = np.square(u)
#             r2 = 1.0 + self.k * np.sum(r1)
#             u_record[i] = r1 / r2
#
#         # Output the activity matrix (optionally the whole trajectory).
#         out = r1 / r2
#         if trajactory_mode:
#             out = u_record
#         return out
--------------------------------------------------------------------------------
/src/tools/utils.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
import torch
import torch.utils.model_zoo
from torch.utils.data import Dataset
import torchvision
import math
import numpy as np
import random
from sklearn.metrics import precision_recall_curve

def least_common_multiple(num):
    # Least common multiple of the per-modality sequence lengths.
    mini = 1
    for i in num:
        mini = int(i) * int(mini) // math.gcd(int(i), int(mini))
    return mini

class SequentialDataset(Dataset):
    def __init__(self, data_dir, exp_idx, transform, nclass=None, seq_len_aps=None, seq_len_dvs=None,
                 seq_len_gps=None, seq_len_head=None, seq_len_time=None):
        self.data_dir = data_dir
        self.transform = transform
        self.num_exp = len(exp_idx)
        self.exp_idx = exp_idx

        self.total_aps = [np.sort(os.listdir(data_dir + str(idx) + '/dvs_frames')) for idx in exp_idx]
        self.total_dvs = [np.sort(os.listdir(data_dir + str(idx) + '/dvs_7ms_3seq')) for idx in exp_idx]
        self.num_imgs = [len(x) for x in self.total_aps]
        self.raw_pos = [np.loadtxt(data_dir + str(idx) + '/position.txt', delimiter=' ') for idx in exp_idx]
        self.raw_head = [np.loadtxt(data_dir + str(idx) + '/direction.txt', delimiter=' ') for idx in exp_idx]

        # Frame file names encode their timestamps; column 0 of the raw
        # position/heading arrays is the timestamp as well.
        self.t_pos = [x[:, 0] for x in self.raw_pos]
        self.t_aps = [[float(x[:-4]) for x in y] for y in self.total_aps]
        self.t_dvs = [[float(x[:-4]) for x in y] for y in self.total_dvs]
        self.data_pos = [idx[:, 0:3] - idx[:, 0:3].min(axis=0) for idx in self.raw_pos]
        self.data_head = [idx[:, 0:3] - idx[:, 0:3].min(axis=0) for idx in self.raw_head]
        self.seq_len_aps = seq_len_aps
        self.seq_len_gps = seq_len_gps
        self.seq_len_dvs = seq_len_dvs
        self.seq_len_head = seq_len_head
        self.seq_len_time = seq_len_time
        self.seq_len = max(seq_len_gps, seq_len_aps)
        self.nclass = nclass

        self.lens = len(self.total_aps) - self.seq_len
        self.dvs_data = None
        self.duration = [x[-1] - x[0] for x in self.t_dvs]

        # The usable length is bounded by the shortest stream across modalities.
        nums = 1e5
        for x in self.total_aps:
            if len(x) < nums: nums = len(x)
        for x in self.total_dvs:
            if len(x) < nums: nums = len(x)
        for x in self.raw_pos:
            if len(x) < nums: nums = len(x)

        self.lens = nums

    def __len__(self):
        return self.lens - self.seq_len * 2

    def __getitem__(self, idx):
        # Each sample is drawn from a randomly chosen recording.
        exp_index = np.random.randint(self.num_exp)
        idx = max(min(idx, self.num_imgs[exp_index] - self.seq_len * 2), self.seq_len_dvs * 3)
        img_seq = []
        for i in range(self.seq_len_aps):
            img_loc = self.data_dir + str(self.exp_idx[exp_index]) + '/dvs_frames/' + \
                      self.total_aps[exp_index][idx - self.seq_len_aps + i]
            img_seq += [Image.open(img_loc).convert('RGB')]
        img_seq_pt = []

        if self.transform:
            for images in img_seq:
                img_seq_pt += [torch.unsqueeze(self.transform(images), 0)]

        img_seq = torch.cat(img_seq_pt, dim=0)
        t_stamps = self.raw_pos[exp_index][:, 0]
        t_target = self.t_aps[exp_index][idx]

        idx_pos = max(np.searchsorted(t_stamps, t_target), self.seq_len_aps)
        pos_seq = self.data_pos[exp_index][idx_pos - self.seq_len_gps:idx_pos, :]
        pos_seq = torch.from_numpy(pos_seq.astype('float32'))

        t_stamps = self.raw_head[exp_index][:, 0]
        t_target = self.t_aps[exp_index][idx]
        idx_head = max(np.searchsorted(t_stamps, t_target), self.seq_len_aps)
        # Heading window aligned to the head-direction timestamps.
        head_seq = self.data_head[exp_index][idx_head - self.seq_len_head:idx_head, 1]
        head_seq = torch.from_numpy(head_seq.astype('float32')).reshape(-1, 1)

        idx_dvs = np.searchsorted(self.t_dvs[exp_index], t_target, sorter=None) - 1
        t_stamps = self.t_dvs[exp_index][idx_dvs]
        dvs_seq = torch.zeros(self.seq_len_dvs * 3, 2, 130, 173)
        for i in range(self.seq_len_dvs):
            dvs_path = self.data_dir + str(self.exp_idx[exp_index]) + '/dvs_7ms_3seq/' \
                       + self.total_dvs[exp_index][idx_dvs - self.seq_len_dvs + i + 1]
            dvs_buf = torch.load(dvs_path)
            dvs_buf = dvs_buf.permute([1, 0, 2, 3])
            dvs_seq[i * 3: (i + 1) * 3] = torch.nn.functional.avg_pool2d(dvs_buf, 2)
        # Class label: the timestamp's relative position within the recording,
        # quantized into nclass bins.
        ids = int((t_stamps - self.t_dvs[exp_index][0]) / self.duration[exp_index] * self.nclass)

        ids = np.clip(ids, a_min=0, a_max=self.nclass - 1)
        ids = np.array(ids)
        ids = torch.from_numpy(ids).type(torch.long)
        return (img_seq, pos_seq, dvs_seq, head_seq), ids

def Data(data_path=None, batch_size=None, exp_idx=None, is_shuffle=True, normalize=None, nclass=None, seq_len_aps=None,
         seq_len_dvs=None, seq_len_gps=None, seq_len_head=None, seq_len_time=None):
    dataset = SequentialDataset(data_dir=data_path,
                                exp_idx=exp_idx,
                                transform=torchvision.transforms.Compose([
                                    torchvision.transforms.Resize((240, 320)),
                                    torchvision.transforms.ToTensor(),
                                    normalize,
                                ]),
                                nclass=nclass,
                                seq_len_aps=seq_len_aps,
                                seq_len_dvs=seq_len_dvs,
                                seq_len_gps=seq_len_gps,
                                seq_len_head=seq_len_head,
                                seq_len_time=seq_len_time)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              shuffle=is_shuffle,
                                              drop_last=True,
                                              batch_size=batch_size,
                                              pin_memory=True,
                                              num_workers=8)
    return data_loader

def Data_brightness(data_path=None, batch_size=None, exp_idx=None, is_shuffle=True, normalize=None, nclass=None, seq_len_aps=None,
                    seq_len_dvs=None, seq_len_gps=None, seq_len_head=None, seq_len_time=None):
    # Same pipeline as Data, with tensor-mode brightness jitter appended.
    dataset = SequentialDataset(data_dir=data_path,
                                exp_idx=exp_idx,
                                transform=torchvision.transforms.Compose([
                                    torchvision.transforms.Resize((240, 320)),
                                    torchvision.transforms.ToTensor(),
                                    normalize,
                                    torchvision.transforms.ColorJitter(brightness=0.5),
                                ]),
                                nclass=nclass,
                                seq_len_aps=seq_len_aps,
                                seq_len_dvs=seq_len_dvs,
                                seq_len_gps=seq_len_gps,
                                seq_len_head=seq_len_head,
                                seq_len_time=seq_len_time)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              shuffle=is_shuffle,
                                              drop_last=True,
                                              batch_size=batch_size,
                                              pin_memory=True,
                                              num_workers=8)
    return data_loader

def Data_mask(data_path=None, batch_size=None, exp_idx=None, is_shuffle=True, normalize=None, nclass=None, seq_len_aps=None,
              seq_len_dvs=None, seq_len_gps=None, seq_len_head=None, seq_len_time=None):
    # Same pipeline as Data, with a padded random crop as an occlusion-style mask.
    dataset = SequentialDataset(data_dir=data_path,
                                exp_idx=exp_idx,
                                transform=torchvision.transforms.Compose([
                                    torchvision.transforms.Resize((240, 320)),
                                    torchvision.transforms.ToTensor(),
                                    normalize,
                                    torchvision.transforms.RandomCrop(size=128, padding=128),
                                ]),
                                nclass=nclass,
                                seq_len_aps=seq_len_aps,
                                seq_len_dvs=seq_len_dvs,
                                seq_len_gps=seq_len_gps,
                                seq_len_head=seq_len_head,
                                seq_len_time=seq_len_time)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              shuffle=is_shuffle,
                                              drop_last=True,
                                              batch_size=batch_size,
                                              pin_memory=True,
                                              num_workers=8)
    return data_loader

def setup_seed(seed):
    # Seed every RNG source for reproducibility.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def accuracy(output, target, topk=(1,)):
    # Top-k accuracy: count samples whose label is among the k highest logits.
    with torch.no_grad():
        _, pred = output.topk(max(topk), dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = [correct[:k].sum().item() for k in topk]
        return res

def compute_matches(retrieved_all, ground_truth_info):
    matches = []
    for itr, retr in enumerate(retrieved_all):
        matches.append(1 if retr == ground_truth_info[itr] else 0)
    return matches

def compute_precision_recall(matches, scores_all):
    precision, recall, _ = precision_recall_curve(matches, scores_all)
    return precision, recall
--------------------------------------------------------------------------------
/src/main/main_mhnn.py:
--------------------------------------------------------------------------------
import time
import argparse
import logging

import torch.utils.model_zoo
import torchmetrics

from src.model.mhnn import *
from src.tools.utils import *
from src.config.config_utils import *

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

setup_seed(0)

parser = argparse.ArgumentParser(description='mhnn', argument_default=argparse.SUPPRESS)
parser.add_argument('--config_file', type=str, required=True, help='Configuration file path')

def main_mhnn(options: argparse.Namespace):
    config = get_config(options.config_file, logger=log)['main']
    config = update_config(config, options.__dict__)

    cnn_arch = str(config['cnn_arch'])
    num_epoch = int(config['num_epoch'])
    batch_size = int(config['batch_size'])
    num_class = int(config['num_class'])
    rnn_num = int(config['rnn_num'])
    cann_num = int(config['cann_num'])
    reservoir_num = int(config['reservoir_num'])
    num_iter = int(config['num_iter'])
    spiking_threshold = float(config['spiking_threshold'])
    sparse_lambdas = int(config['sparse_lambdas'])
    lr = float(config['lr'])
    r = float(config['r'])

    ann_pre_load = get_bool_from_config(config, 'ann_pre_load')
    snn_pre_load = get_bool_from_config(config, 'snn_pre_load')
    re_trained = get_bool_from_config(config, 're_trained')

    seq_len_aps = int(config['seq_len_aps'])
    seq_len_gps = int(config['seq_len_gps'])
    seq_len_dvs = int(config['seq_len_dvs'])
    seq_len_head = int(config['seq_len_head'])
    seq_len_time = int(config['seq_len_time'])

    dvs_expand = int(config['dvs_expand'])
    expand_len = least_common_multiple([seq_len_aps, seq_len_dvs * dvs_expand, seq_len_gps])

    # Experiment indices are written as single digits separated by commas
    # (e.g. '5,7'), so they can be parsed character by character.
    test_exp_idx = []
    for idx in config['test_exp_idx']:
        if idx != ',':
            test_exp_idx.append(int(idx))

    train_exp_idx = []
    for idxt in config['train_exp_idx']:
        if idxt != ',':
            train_exp_idx.append(int(idxt))

    data_path = str(config['data_path'])
    snn_path = str(config['snn_path'])
    hnn_path = str(config['hnn_path'])
    model_saving_file_name = str(config['model_saving_file_name'])

    w_fps = int(config['w_fps'])
    w_gps = int(config['w_gps'])
    w_dvs = int(config['w_dvs'])
    w_head = int(config['w_head'])
    w_time = int(config['w_time'])

    device_id = str(config['device_id'])

    normalize = torchvision.transforms.Normalize(mean=[0.3537, 0.3537, 0.3537],
                                                 std=[0.3466, 0.3466, 0.3466])
    train_loader = Data(data_path, batch_size=batch_size, exp_idx=train_exp_idx, is_shuffle=True,
                        normalize=normalize, nclass=num_class,
                        seq_len_aps=seq_len_aps, seq_len_dvs=seq_len_dvs, seq_len_gps=seq_len_gps,
                        seq_len_head=seq_len_head, seq_len_time=seq_len_time)

    test_loader = Data(data_path, batch_size=batch_size, exp_idx=test_exp_idx, is_shuffle=True,
                       normalize=normalize, nclass=num_class,
                       seq_len_aps=seq_len_aps, seq_len_dvs=seq_len_dvs, seq_len_gps=seq_len_gps,
                       seq_len_head=seq_len_head, seq_len_time=seq_len_time)

    mhnn = MHNN(device=device,
                cnn_arch=cnn_arch,
                num_epoch=num_epoch,
                batch_size=batch_size,
                num_class=num_class,
                rnn_num=rnn_num,
                cann_num=cann_num,
                reservoir_num=reservoir_num,
                spiking_threshold=spiking_threshold,
                sparse_lambdas=sparse_lambdas,
                r=r,
                lr=lr,
                w_fps=w_fps,
                w_gps=w_gps,
                w_dvs=w_dvs,
                w_head=w_head,
                w_time=w_time,
                seq_len_aps=seq_len_aps,
                seq_len_gps=seq_len_gps,
                seq_len_dvs=seq_len_dvs,
                seq_len_head=seq_len_head,
                seq_len_time=seq_len_time,
                dvs_expand=dvs_expand,
                expand_len=expand_len,
                train_exp_idx=train_exp_idx,
                test_exp_idx=test_exp_idx,
                data_path=data_path,
                snn_path=snn_path,
                hnn_path=hnn_path,
                num_iter=num_iter,
                ann_pre_load=ann_pre_load,
                snn_pre_load=snn_pre_load,
                re_trained=re_trained)

    # Calibrate the CANN on the position and heading traces of the first
    # training recording.
    mhnn.cann_init(np.concatenate((train_loader.dataset.data_pos[0],
                                   train_loader.dataset.data_head[0][:, 1].reshape(-1, 1)), axis=1))

    mhnn.to(device)

    optimizer = torch.optim.Adam(mhnn.parameters(), lr)

    lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    criterion = nn.CrossEntropyLoss()

    record = {}
    record['loss'], record['top1'], record['top5'], record['top10'] = [], [], [], []
    best_test_acc1, best_test_acc5, best_recall, best_test_acc10 = 0., 0., 0., 0.

    train_iters = iter(train_loader)
    iters = 0
    start_time = time.time()
    print(device)

    # torchmetrics classification metrics (pre-0.11 API without the 'task' argument).
    test_acc = torchmetrics.Accuracy()
    test_recall = torchmetrics.Recall(average='none', num_classes=num_class)
    test_precision = torchmetrics.Precision(average='none', num_classes=num_class)

    for epoch in range(num_epoch):

        ###############
        ## for training
        ###############

        running_loss = 0.
        counts = 1.
        acc1_record, acc5_record, acc10_record = 0., 0., 0.
        while iters < num_iter:
            mhnn.train()
            optimizer.zero_grad()

            try:
                inputs, target = next(train_iters)
            except StopIteration:
                train_iters = iter(train_loader)
                inputs, target = next(train_iters)

            outputs, outs = mhnn(inputs, epoch=epoch)

            class_loss = criterion(outputs, target.to(device))

            sparse_loss = 0.
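            # Firing-rate regularization: each entry of `outs` is the summed
            # spiking activity of one reservoir population; its mean rate is
            # pulled toward the target sparsity r with a per-modality weight
            # (MHNN.forward currently returns the APS and DVS populations, so
            # only the first two weights below are used).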
            w_module = (w_fps, w_dvs, w_gps, w_gps)
            for (i, l) in enumerate(outs):
                sparse_loss += (l.mean() - r) ** 2 * w_module[i]

            loss = class_loss + sparse_lambdas * sparse_loss

            loss.backward()

            optimizer.step()

            running_loss += loss.cpu().item()

            acc1, acc5, acc10 = accuracy(outputs.cpu(), target, topk=(1, 5, 10))
            acc1, acc5, acc10 = acc1 / len(outputs), acc5 / len(outputs), acc10 / len(outputs)

            acc1_record += acc1
            acc5_record += acc5
            acc10_record += acc10

            counts += 1
            iters += 1
        iters = 0

        record['loss'].append(loss.item())
        print('\n\nTime elapsed:', time.time() - start_time, 's')
        print(
            'Training epoch %d, training loss: %.4f, sparse loss: %.4f, training Top1 acc: %.4f, training Top5 acc: %.4f' %
            (epoch, running_loss / num_iter, sparse_lambdas * sparse_loss, acc1_record / counts, acc5_record / counts))

        lr_schedule.step()
        start_time = time.time()

        ##############
        ## for testing
        ##############

        running_loss = 0.
        mhnn.eval()

        with torch.no_grad():

            acc1_record, acc5_record, acc10_record = 0., 0., 0.
            counts = 1.

            for batch_idx, (inputs, target) in enumerate(test_loader):

                outputs, _ = mhnn(inputs, epoch=epoch)

                loss = criterion(outputs.cpu(), target)

                running_loss += loss.item()
                acc1, acc5, acc10 = accuracy(outputs.cpu(), target, topk=(1, 5, 10))
                acc1, acc5, acc10 = acc1 / len(outputs), acc5 / len(outputs), acc10 / len(outputs)
                acc1_record += acc1
                acc5_record += acc5
                acc10_record += acc10

                counts += 1
                outputs = outputs.cpu()

                test_acc(outputs.argmax(1), target)
                test_recall(outputs.argmax(1), target)
                test_precision(outputs.argmax(1), target)

                compute_num = 100
                if batch_idx % compute_num == 0 and batch_idx > 0:
                    # print('time:', (time.time() - start_time) / compute_num)
                    start_time = time.time()

        total_acc = test_acc.compute().mean()
        total_recall = test_recall.compute().mean()
        total_precision = test_precision.compute().mean()

        test_precision.reset()
        test_recall.reset()
        test_acc.reset()

        acc1_record = acc1_record / counts
        acc5_record = acc5_record / counts
        acc10_record = acc10_record / counts

        record['top1'].append(acc1_record)
        record['top5'].append(acc5_record)
        record['top10'].append(acc10_record)

        print('Current best Top1: ', best_test_acc1, 'Best Top5:', best_test_acc5)

        if epoch > 4:
            if best_test_acc1 < acc1_record:
                best_test_acc1 = acc1_record
                print('Achieving the best Top1, saving...', best_test_acc1)

            if best_test_acc5 < acc5_record:
                best_test_acc5 = acc5_record
                print('Achieving the best Top5, saving...', best_test_acc5)

            if best_recall < total_recall:
                best_recall = total_recall
                print('Achieving the best recall, saving...', best_recall)

            if best_test_acc10 < acc10_record:
                best_test_acc10 = acc10_record

            state = {
                'net': mhnn.state_dict(),
                'snn': mhnn.snn.state_dict(),
                'record': record,
                'best_recall': best_recall,
                'best_acc1': best_test_acc1,
                'best_acc5': best_test_acc5,
                'best_acc10': best_test_acc10
            }
            if not os.path.isdir('../../checkpoint'):
                os.mkdir('../../checkpoint')
            torch.save(state, '../../checkpoint/' + model_saving_file_name + '.t7')


if __name__ == '__main__':
    options, unknowns = parser.parse_known_args()
    main_mhnn(options)
--------------------------------------------------------------------------------
/src/model/mhnn.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.utils.model_zoo
import torchvision.models as models
from src.model.snn import *
from src.model.cann import CANN
from src.tools.utils import *

# GPU environment
# os.environ['CUDA_VISIBLE_DEVICES'] = "3"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class MHNN(nn.Module):
    def __init__(self, **kwargs):
        super(MHNN, self).__init__()

        self.cnn_arch = kwargs.get('cnn_arch')
        self.num_class = kwargs.get('num_class')
        self.cann_num = kwargs.get('cann_num')
        self.rnn_num = kwargs.get('rnn_num')
        self.lr = kwargs.get('lr')
        self.batch_size = kwargs.get('batch_size')
        self.sparse_lambdas = kwargs.get('sparse_lambdas')
        self.r = kwargs.get('r')

        self.reservoir_num = kwargs.get('reservoir_num')
        self.threshold = kwargs.get('spiking_threshold')

        self.num_epoch = kwargs.get('num_epoch')
        self.num_iter = kwargs.get('num_iter')
        self.w_fps = kwargs.get('w_fps')
        self.w_gps = kwargs.get('w_gps')
        self.w_dvs = kwargs.get('w_dvs')
        self.w_head = kwargs.get('w_head')
        self.w_time = kwargs.get('w_time')

        self.seq_len_aps = kwargs.get('seq_len_aps')
        self.seq_len_gps = kwargs.get('seq_len_gps')
        self.seq_len_dvs = kwargs.get('seq_len_dvs')
        self.seq_len_head = kwargs.get('seq_len_head')
        self.seq_len_time = kwargs.get('seq_len_time')
        self.dvs_expand = kwargs.get('dvs_expand')

        self.ann_pre_load = kwargs.get('ann_pre_load')
        self.snn_pre_load = kwargs.get('snn_pre_load')
        self.re_trained = kwargs.get('re_trained')

        self.train_exp_idx = kwargs.get('train_exp_idx')
        self.test_exp_idx = kwargs.get('test_exp_idx')

        self.data_path = kwargs.get('data_path')
        self.snn_path = kwargs.get('snn_path')

        # self.device = kwargs.get('device')
        self.device = device

        if self.ann_pre_load:
            print("=> Loading pre-trained model '{}'".format(self.cnn_arch))
            self.cnn = models.__dict__[self.cnn_arch](pretrained=self.ann_pre_load)
        else:
            print("=> Using randomly initialized model '{}'".format(self.cnn_arch))
            self.cnn = models.__dict__[self.cnn_arch](pretrained=self.ann_pre_load)

        if self.cnn_arch == "mobilenet_v2":
            """ MobileNet """
            # Use the backbone as a feature extractor: drop the classifier head.
            self.feature_dim = self.cnn.classifier[1].in_features
            self.cnn.classifier[1] = nn.Identity()

        elif self.cnn_arch == "resnet50":
            """ Resnet50 """

            self.feature_dim = 512

            # Truncate the backbone: keep only the first bottleneck of layer1
            # and replace the deeper stages with identities.
            # self.cnn.layer1 = nn.Identity()
            self.cnn.layer2 = nn.Identity()
            self.cnn.layer3 = nn.Identity()
            self.cnn.layer4 = nn.Identity()
            self.cnn.fc = nn.Identity()

            self.cnn.layer1[1] = nn.Identity()
            self.cnn.layer1[2] = nn.Identity()

            # self.cnn.layer2[0] = nn.Identity()
            # self.cnn.layer2[0].conv2 = nn.Identity()
            # self.cnn.layer2[0].bn2 = nn.Identity()

            fc_inputs = 256
            # Project the 256-d truncated-layer1 features up to feature_dim.
            self.cnn.fc = nn.Linear(fc_inputs, self.feature_dim)

        else:
            print("=> Please check model name or configure architecture for feature extraction only, exiting...")
            exit()

        for param in self.cnn.parameters():
            param.requires_grad = self.re_trained

        #############
        # SNN module
        #############
        self.snn = SNN(device=self.device).to(self.device)
        self.snn_out_dim = self.snn.fc2.weight.size()[1]
        self.ann_out_dim = self.feature_dim
        self.cann_out_dim = 4 * self.cann_num
        self.reservoir_inp_num = self.ann_out_dim + self.snn_out_dim + self.cann_out_dim
        self.LN = nn.LayerNorm(self.reservoir_inp_num)
        if self.snn_pre_load:
            self.snn.load_state_dict(torch.load(self.snn_path)['snn'])

        #############
        # CANN module
        #############
        self.cann_num = self.cann_num
        self.cann = None
        self.num_class = self.num_class

        #############
        # MLSM module
        #############
        self.input_size = self.feature_dim
        self.reservoir_num = self.reservoir_num

        self.threshold = 0.5
        self.decay = nn.Parameter(torch.rand(self.reservoir_num))

        self.K = 128
        self.num_block = 5
        self.num_blockneuron = int(self.reservoir_num / self.num_block)

        self.decay_scale = 0.5
        self.beta_scale = 0.1

        self.thr_base1 = self.threshold
        self.thr_beta1 = nn.Parameter(self.beta_scale * torch.rand(self.reservoir_num))
        self.thr_decay1 = nn.Parameter(self.decay_scale * torch.rand(self.reservoir_num))

        self.ref_base1 = self.threshold
        self.ref_beta1 = nn.Parameter(self.beta_scale * torch.rand(self.reservoir_num))
        self.ref_decay1 = nn.Parameter(self.decay_scale * torch.rand(self.reservoir_num))

        self.cur_base1 = 0
        self.cur_beta1 = nn.Parameter(self.beta_scale * torch.rand(self.reservoir_num))
        self.cur_decay1 = nn.Parameter(self.decay_scale * torch.rand(self.reservoir_num))

        self.project = nn.Linear(self.reservoir_inp_num, self.reservoir_num)

        self.project_mask_matrix = torch.zeros((self.reservoir_inp_num, self.reservoir_num))

        # Block-diagonal input mask: each segment of the concatenated input
        # (APS, DVS, position, time, head direction) is wired to its own
        # reservoir block, matching the five blocks above.
        input_node_list = [0, self.ann_out_dim, self.snn_out_dim, self.cann_num * 2, self.cann_num, self.cann_num]

        input_cum_list = np.cumsum(input_node_list)

        for i in range(len(input_cum_list) - 1):
            self.project_mask_matrix[input_cum_list[i]:input_cum_list[i + 1],
                                     self.num_blockneuron * i:self.num_blockneuron * (i + 1)] = 1

        self.project.weight.data = self.project.weight.data * self.project_mask_matrix.t()

        self.lateral_conn = nn.Linear(self.reservoir_num, self.reservoir_num)

        # Control the ratio of lateral connections (~20% are kept).
        self.lateral_conn_mask = torch.rand(self.reservoir_num, self.reservoir_num) > 0.8
        # Remove self-recurrent connections.
        self.lateral_conn_mask = self.lateral_conn_mask * (1 - torch.eye(self.reservoir_num, self.reservoir_num))
        # Adjust the initial weight of the connections.
        self.lateral_conn.weight.data = 1e-3 * self.lateral_conn.weight.data * self.lateral_conn_mask.T

        #############
        # readout module
        #############

        # self.mlp1 = nn.Linear(self.reservoir_num, 256)
        # self.mlp2 = nn.Linear(256, self.num_class)
        self.mlp = nn.Linear(self.reservoir_num, self.num_class)

    def cann_init(self, data):
        self.cann = CANN(data)

    def lr_initial_schedule(self, lrs=1e-3):
        hyper_param_list = ['decay',
                            'thr_beta1', 'thr_decay1',
                            'ref_beta1', 'ref_decay1',
                            'cur_beta1', 'cur_decay1']
        hyper_params = list(filter(lambda x: x[0] in hyper_param_list, self.named_parameters()))
        base_params = list(filter(lambda x: x[0] not in hyper_param_list, self.named_parameters()))
        hyper_params = [x[1] for x in hyper_params]
        base_params = [x[1] for x in base_params]
        optimizer = torch.optim.SGD(
            [
                {'params': base_params, 'lr': lrs},
                {'params': hyper_params, 'lr': lrs / 2},
            ], lr=lrs, momentum=0.9, weight_decay=1e-7
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(self.num_epoch))
        return optimizer, scheduler

    def wta_mem_update(self, fc, fv, k, inputx, spike, mem, thr, ref, last_cur):
        # scaling_lateral_inputs = 0.1
        # state = fc(inputx) + fv(spike) * scaling_lateral_inputs
        state = fc(inputx) + fv(spike)

        mem = (mem - spike * ref) * self.decay + state + last_cur
        # Winner ratio: keep roughly the top 20% (q0 quantile) of neurons per group.
        q0 = (100 - 20) / 100
        mem = mem.reshape(self.batch_size, self.num_blockneuron, -1)
        nps = torch.quantile(mem, q=q0, keepdim=True, axis=1)
        # nps = mem.max(axis=1, keepdim=True)[0] - thr

        # Allow only the winners to fire spikes.
        mem = F.relu(mem - nps).reshape(self.batch_size, -1)

        # Spiking firing function.
        spike = act_fun(mem - thr)
        return spike.float(), mem

    def trace_update(self, trace, spike, decay):
        return trace * decay + spike

    def forward(self, inp, epoch=100):
        aps_inp = inp[0].to(self.device)
        gps_inp = inp[1]
        dvs_inp = inp[2].to(self.device)
        head_inp = inp[3].to(self.device)

        batch_size, seq_len, channel, w, h = aps_inp.size()
        if self.w_fps > 0:
            aps_inp = aps_inp.view(batch_size * seq_len, channel, w, h)
            aps_out = self.cnn(aps_inp)
            out1 = aps_out.reshape(batch_size, self.seq_len_aps, -1).permute([1, 0, 2])
        else:
            out1 = torch.zeros(self.seq_len_aps, batch_size, self.feature_dim, device=self.device).to(torch.float32)

        if self.w_dvs > 0:
            out2 = self.snn(dvs_inp, out_mode='time')
        else:
            out2 = torch.zeros(self.seq_len_dvs * 3, batch_size, self.snn_out_dim, device=self.device).to(torch.float32)

        ### CANN module

        if self.w_gps + self.w_head + self.w_time > 0:
            gps_record = []
            for idx in range(batch_size):
                buf = self.cann.update(torch.cat((gps_inp[idx], head_inp[idx].cpu()), axis=1), trajactory_mode=True)
                gps_record.append(buf[None, :, :, :])
            gps_out = torch.from_numpy(np.concatenate(gps_record)).to(self.device)
            gps_out = gps_out.permute([1, 0, 2, 3]).reshape(self.seq_len_gps, batch_size, -1)
        else:
            gps_out = torch.zeros((self.seq_len_gps, batch_size, self.cann_out_dim), device=self.device)

        # A generic CANN module was used for rapid testing; CANN1D/2D are provided in cann.py
        out3 = gps_out[:, :, self.cann_num:self.cann_num * 3].to(self.device).to(torch.float32)  # position
        out4 = gps_out[:, :, :self.cann_num].to(self.device).to(torch.float32)   # time
        out5 = gps_out[:, :, -self.cann_num:].to(self.device).to(torch.float32)  # direction

        out3 *= self.w_gps
        out4 *= self.w_time
        out5 *= self.w_head

        expand_len = int(self.seq_len_gps * self.seq_len_aps / np.gcd(self.seq_len_gps, self.seq_len_dvs) * self.dvs_expand)
        input_num = self.feature_dim + self.snn.fc2.weight.size()[1] + self.cann_num * self.dvs_expand

        r_spike = r_mem = r_sumspike = torch.zeros(batch_size, self.reservoir_num, device=self.device)

        thr_trace = torch.zeros(batch_size, self.reservoir_num, device=self.device)
        ref_trace = torch.zeros(batch_size, self.reservoir_num, device=self.device)
        cur_trace = torch.zeros(batch_size, self.reservoir_num, device=self.device)

        # K_winner anneals from 2K down to K over training (passed to
        # wta_mem_update, where the winner ratio q0 is currently hardcoded).
        K_winner = self.K * (1 + np.clip(self.num_epoch - epoch, a_min=0, a_max=self.num_epoch) / self.num_epoch)

        out1_zeros = torch.zeros_like(out1[0], device=self.device)
        out3_zeros = torch.zeros_like(out3[0], device=self.device)
        out4_zeros = torch.zeros_like(out4[0], device=self.device)
        out5_zeros = torch.zeros_like(out5[0], device=self.device)

        # Re-apply the structural masks so pruned connections stay at zero.
        self.project.weight.data = self.project.weight.data * self.project_mask_matrix.t().to(self.device)
        self.lateral_conn.weight.data = self.lateral_conn.weight.data * self.lateral_conn_mask.T.to(self.device)

        self.decay.data = torch.clamp(self.decay.data, min=0, max=1.)
        self.thr_decay1.data = torch.clamp(self.thr_decay1.data, min=0, max=1.)
        self.ref_decay1.data = torch.clamp(self.ref_decay1.data, min=0, max=1.)
        self.cur_decay1.data = torch.clamp(self.cur_decay1.data, min=0, max=1.)

        for step in range(expand_len):

            idx = step % 3
            # Simulate multimodal information with different time scales: the
            # DVS stream updates every step, while the other modalities are
            # presented only on every third step.
            if idx == 2:
                combined_input = torch.cat((out1[step // 3], out2[step], out3[step // 3], out4[step // 3], out5[step // 3]), axis=1)
            else:
                combined_input = torch.cat((out1_zeros, out2[step], out3_zeros, out4_zeros, out5_zeros), axis=1)

            thr = self.thr_base1 + thr_trace * self.thr_beta1
            # ref = self.ref_base1 + ref_trace * self.ref_beta1  # option: adaptive refractoriness
            ref = self.ref_base1
            cur = self.cur_base1 + cur_trace * self.cur_beta1

            inputx = combined_input.float()
            r_spike, r_mem = self.wta_mem_update(self.project, self.lateral_conn, K_winner, inputx, r_spike, r_mem,
                                                 thr, ref, cur)
            thr_trace = self.trace_update(thr_trace, r_spike, self.thr_decay1)
            ref_trace = self.trace_update(ref_trace, r_spike, self.ref_decay1)
            cur_trace = self.trace_update(cur_trace, r_spike, self.cur_decay1)
            r_sumspike = r_sumspike + r_spike

        # cat_out = F.dropout(r_sumspike, p=0.5, training=self.training)
        # out1 = self.mlp1(r_sumspike).relu()
        # out2 = self.mlp2(out1)
        out2 = self.mlp(r_sumspike)

        # Return the logits plus the APS and DVS reservoir populations, which
        # the training loop uses for the firing-rate regularizer.
        neuron_pop = r_sumspike.reshape(batch_size, -1, self.num_blockneuron).permute([1, 0, 2])
        return out2, (neuron_pop[0], neuron_pop[1])
--------------------------------------------------------------------------------
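The precision-recall helpers in src/tools/utils.py (compute_matches, compute_precision_recall) are not called by main_mhnn.py. A minimal sketch of one way to use them with a trained model follows; the `mhnn` and `test_loader` names are assumed to refer to objects built as in main_mhnn.py, and this is an illustration rather than repository code:

```python
import torch
from src.tools.utils import compute_matches, compute_precision_recall

retrieved, truth, scores = [], [], []
with torch.no_grad():
    for inputs, target in test_loader:        # a loader built with Data(...)
        outputs, _ = mhnn(inputs)             # trained MHNN instance
        probs = outputs.softmax(dim=1).cpu()
        conf, pred = probs.max(dim=1)         # top-1 class and its confidence
        retrieved += pred.tolist()
        truth += target.tolist()
        scores += conf.tolist()

matches = compute_matches(retrieved, truth)   # 1 where prediction matches the label
precision, recall = compute_precision_recall(matches, scores)
```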