├── src
│   ├── config
│   │   ├── __init__.py
│   │   ├── config_utils.py
│   │   ├── room_setting.ini
│   │   ├── forest_setting.ini
│   │   └── corridor_setting.ini
│   ├── model
│   │   ├── snn.py
│   │   ├── cann.py
│   │   └── mhnn.py
│   ├── tools
│   │   └── utils.py
│   └── main
│       └── main_mhnn.py
├── LICENSE
└── README.md
/src/config/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
--------------------------------------------------------------------------------
/src/config/config_utils.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | from os.path import isfile, join
3 | from config import ROOT_DIR
4 |
5 | def get_config(config_file, logger):
6 | config = configparser.ConfigParser()
7 | configFile = join(ROOT_DIR, config_file)
8 | logger.info(f'The path of the configuration file is {configFile}')
9 | assert isfile(configFile)
10 | config.read(configFile)
11 | return config
12 |
13 | def update_config(config, updates):
14 | for k in updates:
15 | config[k] = updates[k]
16 | return config
17 |
18 | def get_bool_from_config(config, key, default_value = False):
19 | if key in config:
20 | return config[key] == 'True'
21 | else:
22 | return default_value
23 |
--------------------------------------------------------------------------------
/src/config/room_setting.ini:
--------------------------------------------------------------------------------
1 | [main]
2 | cnn_arch = mobilenet_v2
3 | num_class = 100
4 | rnn_num = 128
5 | cann_num = 128
6 | lr = 2.5e-4
7 | reservoir_num = 4000
8 | sparse_lambdas = 5
9 | r = 0.15
10 | batch_size = 20
11 | spiking_threshold = 0.5
12 |
13 | num_epoch = 30
14 | num_iter = 100
15 |
16 | train_exp_idx = 1
17 | test_exp_idx = 2
18 |
19 | ann_pre_load = True
20 | snn_pre_load = False
21 | re_trained = True
22 |
23 | w_fps = 1
24 | w_gps = 1
25 | w_dvs = 1
26 | w_head = 1
27 | w_time = 1
28 |
29 | dvs_seq = 3
30 | seq_len_aps = 3
31 | seq_len_gps = 3
32 | seq_len_dvs = 3
33 | seq_len_head = 3
34 | seq_len_time = 3
35 | dvs_expand = 3
36 |
37 | data_path = /data/NeuroVPR_Datasets/room/room_v
38 | snn_path = ./pretrained_model/snn_room_model.t7
39 | hnn_path = ./pretrained_model/mhnn_room_model.t7
40 | model_saving_file_name = mhnn_room_v1
41 | device_id = 7
42 |
43 |
44 |
--------------------------------------------------------------------------------
/src/config/forest_setting.ini:
--------------------------------------------------------------------------------
1 | [main]
2 | cnn_arch = mobilenet_v2
3 | num_class = 100
4 | rnn_num = 128
5 | cann_num = 128
6 | lr = 2.5e-4
7 | reservoir_num = 4000
8 | sparse_lambdas = 5
9 | r = 0.15
10 | batch_size = 20
11 | spiking_threshold = 0.5
12 |
13 | num_epoch = 30
14 | num_iter = 100
15 |
16 | train_exp_idx = 1
17 | test_exp_idx = 2
18 |
19 | ann_pre_load = True
20 | snn_pre_load = False
21 | re_trained = True
22 |
23 | w_fps = 1
24 | w_gps = 1
25 | w_dvs = 1
26 | w_head = 1
27 | w_time = 1
28 |
29 | dvs_seq = 3
30 | seq_len_aps = 3
31 | seq_len_gps = 3
32 | seq_len_dvs = 3
33 | seq_len_time = 3
34 | seq_len_head = 3
35 | dvs_expand = 3
36 |
37 | data_path = /data/NeuroVPR_Datasets/campus/forest/forest_v
38 | snn_path = ./pretrained_model/snn_forest_model.t7
39 | hnn_path = ./pretrained_model/mhnn_forest_model.t7
40 | model_saving_file_name = mhnn_forest_v1
41 | device_id = 7
42 |
43 |
44 |
--------------------------------------------------------------------------------
/src/config/corridor_setting.ini:
--------------------------------------------------------------------------------
1 | [main]
2 | cnn_arch = mobilenet_v2
3 | num_class = 100
4 | rnn_num = 128
5 | cann_num = 128
6 | lr = 2.5e-4
7 |
8 | reservoir_num = 4000
9 | spiking_threshold = 0.5
10 | sparse_lambdas = 5
11 | r = 0.15
12 |
13 | batch_size = 20
14 | num_epoch = 30
15 | num_iter = 100
16 |
17 | train_exp_idx = 5,7
18 | test_exp_idx = 9
19 |
20 | ann_pre_load = True
21 | snn_pre_load = False
22 | re_trained = True
23 |
24 | w_fps = 1
25 | w_gps = 1
26 | w_dvs = 1
27 | w_head = 1
28 | w_time = 1
29 |
30 | dvs_seq = 3
31 | seq_len_aps = 3
32 | seq_len_gps = 3
33 | seq_len_dvs = 3
34 | seq_len_time = 3
35 | seq_len_head = 3
36 | dvs_expand = 3
37 |
38 | data_path = /data/NeuroVPR_Datasets/corridor/floor3/floor3_v
39 | snn_path = ./pretrained_snn_model/snn_corridor_model.t7
40 | hnn_path = ./pretrained_snn_model/hnn_corridor_model.t7
41 | model_saving_file_name = mhnn_corridor_v1
42 | device_id = 7
43 |
44 |
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 The authors of the paper 'Brain-inspired multimodal hybrid neural network for robot place recognition'
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Brain-inspired multimodal hybrid neural network for robot place recognition
2 |
3 | ## Table of Contents
4 | 1. About
5 | 2. Structure
6 | 3. Usage
7 | 4. Requirements
8 | 5. Datasets
9 | 6. Citation
10 | 7. License
11 | 8. Acknowledgments
12 | 9. Note
13 |
14 | ## About
15 | This library implements a brain-inspired multimodal hybrid neural network (MHNN) for robot place recognition. The MHNN encodes and integrates multimodal cues from both conventional and neuromorphic sensors. Specifically, to encode different sensory cues, we build various neural networks of spatial view cells, place cells, head direction cells, and time cells. To integrate these cues, we design a multiscale liquid state machine that can process and fuse multimodal information effectively and asynchronously by using diverse neuronal dynamics and bio-inspired inhibitory circuits.
16 |
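17 | One concrete instance of the "bio-inspired inhibitory circuits" above is the blockwise winner-take-all update inside the liquid state machine (`wta_mem_update` in `src/model/mhnn.py`): within each neuron block, only membrane potentials above a per-block quantile stay active and may spike. The snippet below is a simplified, self-contained sketch of that rule; the function name `wta_update` and its constants are illustrative toy values, whereas the real model uses learned per-neuron decays and adaptive thresholds:
18 |
19 | ```python
20 | import torch
21 | import torch.nn.functional as F
22 |
23 | def wta_update(mem, inp, decay=0.6, thr=0.1, q=0.8):
24 |     # leaky integration of the incoming current
25 |     mem = mem * decay + inp
26 |     # blockwise inhibition: subtract each block's q-quantile so that only
27 |     # roughly the top (1 - q) fraction of neurons per block stays positive
28 |     blocks = mem.reshape(mem.size(0), 5, -1)
29 |     nps = torch.quantile(blocks, q, dim=2, keepdim=True)
30 |     mem = F.relu(blocks - nps).reshape(inp.size(0), -1)
31 |     spike = (mem > thr).float()  # surviving neurons above the threshold fire
32 |     return spike, mem
33 |
34 | mem = torch.zeros(4, 100)  # a batch of 4 samples, five blocks of 20 neurons
35 | spike, mem = wta_update(mem, torch.rand(4, 100))
36 | print(spike.sum(dim=1))  # only a small fraction of neurons fires per sample
37 | ```
38 |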
39 | ## Structure
40 | > * src
41 | >> * model: contains the MHNN model
42 | >> * tools: contains the utilities
43 | >> * config: contains the configuration files
44 | >> * main: contains the demo
45 |
46 | ## Usage
47 | * Step 1: set up the running environment
48 | * Step 2: download the datasets
49 | * Step 3: download the code:
50 | `git clone https://github.com/cognav/neurogpr.git`
51 | * Step 4: run the demo from the 'main' folder:
52 | `python main_mhnn.py --config_file ../config/corridor_setting.ini`
53 |
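54 | A checkpoint with the best model is written during training (see the saving block near the end of `main_mhnn.py`) to `checkpoint/<model_saving_file_name>.t7` under the repository root. Below is a minimal sketch for inspecting such a checkpoint; the file name follows the corridor config, so adjust it to your run:
55 |
56 | ```python
57 | import torch
58 |
59 | state = torch.load('checkpoint/mhnn_corridor_v1.t7', map_location='cpu')
60 | print(state['best_acc1'], state['best_acc5'], state['best_acc10'])
61 | # state['net'] holds the full MHNN state dict, state['snn'] the SNN
62 | # sub-module, and state['record'] the per-epoch loss/accuracy curves
63 | # mhnn.load_state_dict(state['net'])  # with an MHNN instance built as in main_mhnn.py
64 | ```
65 |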
66 | ## Requirements
67 | The following libraries are needed to run the demo.
68 | * Python 3.7.4
69 | * Torch 1.11.0
70 | * Torchvision 0.12.0
71 | * Numpy 1.21.5
72 | * Scipy 1.7.1
73 | * Torchmetrics (imported by main_mhnn.py)
74 | * Scikit-learn (imported by src/tools/utils.py)
75 | * Pillow (imported by src/tools/utils.py)
76 |
77 | ## Datasets
78 | The datasets include four groups of data collected in different environments.
79 | * The Room, Corridor, and THU-Forest datasets are available on Zenodo (https://zenodo.org/record/7827108#.ZD_ke3bP0ds).
80 | * The public Brisbane-Event-VPR dataset is available on Zenodo (https://zenodo.org/record/4302805#.ZD8puXbP0ds).
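81 |
82 | `data_path` in each config file is an experiment-folder prefix; the experiment index from `train_exp_idx`/`test_exp_idx` is appended to it. Based on how `SequentialDataset` in `src/tools/utils.py` reads the data, each experiment folder is expected to look as sketched below (corridor example with index 5; the image extension is an assumption, since the loader only strips the last four characters of a file name and parses the rest as a float timestamp):
83 |
84 | ```
85 | /data/NeuroVPR_Datasets/corridor/floor3/floor3_v5
86 | ├── dvs_frames/      # image frames named by timestamp, e.g. 123.456789.png
87 | ├── dvs_7ms_3seq/    # event-frame tensors loadable with torch.load
88 | ├── position.txt     # space-delimited rows; the first column is a timestamp
89 | └── direction.txt    # space-delimited rows; the first column is a timestamp
90 | ```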
91 |
92 | ## Citation
93 | Fangwen Yu, Yujie Wu, Songchen Ma, Mingkun Xu, Hongyi Li, Huanyu Qu, Chenhang Song, Taoyi Wang, Rong Zhao and Luping Shi. Brain-inspired multimodal hybrid neural network for robot place recognition. Science Robotics (accepted).
94 |
95 | ## License
96 | MIT License
97 |
98 | ## Acknowledgments
99 | If you have any questions, please contact us.
100 |
101 | ## Note
102 | If the parameter settings differ from those above, the results may not be consistent. You may need to modify the settings or adjust the code accordingly.
103 |
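104 | As the note says, you may need to adjust settings; a quick way to experiment is to override configuration keys programmatically. Below is a sketch using the helpers in `src/config/config_utils.py`; run it from the repository root with `src` on `PYTHONPATH`, and treat the override values as illustrative:
105 |
106 | ```python
107 | import logging
108 | from src.config.config_utils import get_config, update_config
109 |
110 | logging.basicConfig(level=logging.INFO)
111 | # get_config resolves the path against the src/ directory (ROOT_DIR in src/config/__init__.py)
112 | config = get_config('config/corridor_setting.ini', logger=logging.getLogger())['main']
113 | config = update_config(config, {'batch_size': '10', 'device_id': '0'})  # configparser values are strings
114 | print(config['batch_size'])
115 | ```
116 |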
--------------------------------------------------------------------------------
/src/model/snn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | thresh = 0.15
6 | lens = 0.5
7 | probs = 0.5
8 | decay = 0.6
9 |
10 | cfg_cnn = [(2, 64, 3, 1, 3),
11 | (64, 128, 2, 1, 3),
12 | (128, 256, 2, 1, 3),
13 | (256, 256, 1, 1, 3),
14 | ]
15 | cfg_fc = [512, 512, 200]
16 |
17 | def mem_update(opts, x, mem, spike):
18 | mem = mem * decay * (1 - spike) + opts(x)  # decay the potential, reset neurons that fired, add the new input current
19 | spike = act_fun(mem - thresh)  # fire when the potential crosses thresh (surrogate gradient in backward)
20 | return mem, spike
21 |
22 | class ActFun(torch.autograd.Function):
23 |
24 | @staticmethod
25 | def forward(ctx, input):
26 | ctx.save_for_backward(input)
27 | return input.gt(0).float()
28 |
29 | @staticmethod
30 | def backward(ctx, grad_output):
31 | input, = ctx.saved_tensors
32 | grad_input = grad_output.clone()
33 | temp = abs(input) < lens  # surrogate gradient: rectangular window of width 2*lens around the threshold
34 | return grad_input * temp.float()
35 |
36 | act_fun = ActFun.apply
--------------------------------------------------------------------------------
/src/model/cann.py:
--------------------------------------------------------------------------------
109 | # if d > np.pi: # set up connectivity matrix on a ring
110 | # d = 2*np.pi-d
111 | # self.w[i, j] = self.J0 * np.exp(-0.5 * d) / (np.sqrt(2 * np.pi) * self.a)
112 | #
113 | # def data_reverse_transform(self, x):
114 | # x_min, x_max = np.min(x, axis=0), np.max(x, axis=0)
115 | # x_range = x_max - x_min
116 | # return (x - x_min) / (1e-11 + x_range) * self.z_range + self.z_min
117 | #
118 | # def get_stimulus_by_pos(self, x):
119 | # d = (x - self.centers) / (self.z_range + 1e-4)
120 | # y = np.exp(-1 * np.square(d / 0.5))
121 | # return y
122 | #
123 | # def update(self, data, trajactory_mode=False):
124 | # seq_len, seq_dim = data.shape
125 | # u = np.zeros((self.num, seq_dim))
126 | # u_record = np.zeros((seq_len, self.num, seq_dim))
127 | #
128 | # for i in range(0, seq_len):
129 | # if i == 0:
130 | # cur_stimulus = self.get_stimulus_by_pos(data[0])
131 | # else:
132 | # cur_stimulus = self.get_stimulus_by_pos(data[i - 1])
133 | # t = np.arange(self.I_dur * i, self.I_dur * i + self.I_dur, self.dt)
134 | # cur_du = odeint(int_u, u.flatten(), t,
135 | # args=(self.w, cur_stimulus.t().numpy(), self.tau, self.Inh_inp, self.k, self.dt))
136 | #
137 | # u = cur_du[-1].reshape(-1, seq_dim)
138 | # r1 = np.square(u)
139 | # r2 = 1.0 + 0.5 * self.k * np.sum(r1, axis=0)
140 | # u_record[i] = r1 / r2.reshape(1, -1)
141 | #
142 | # out = r1 / r2
143 | # if trajactory_mode:
144 | # out = u_record
145 | #
146 | # return out
147 |
148 |
149 | # class CANN2D():
150 | # def __init__(self, data, length = 128, tau = 0.02, Inh_inp = -.1, I_dur = 0.4, dt = 0.05, A = 20, k = 1):
151 | # self.length = length # 2D-CANN :(length, length)
152 | # self.tau = tau # key parameters: control the pred curve
153 | # self.Inh_inp = Inh_inp # key parameters: inhibitive stimulus, key parameters
154 | # self.I_dur = I_dur
155 | # self.dt = dt
156 | # self.A = A
157 | # self.k = 1
158 | # self.t_wins = data.shape[0] + 1
159 | # self.n_units = data.shape[1]
160 | # self.z_min, self.z_max = np.min(data), np.max(data) + 1e-7
161 | # self.z_range = self.z_max - self.z_min
162 | # self.x = np.linspace(self.z_min, self.z_max, length) # sample centers
163 | # self.dx = self.z_range / length
164 | #
165 | # self.a = 2
166 | # self.J0 = 4
167 | # self.u = [] # membrane
168 | # self.u.append(np.zeros((length,length)))
169 | # self.r = [] # firing rate
170 | # self.r.append(np.zeros((length,length)))
171 | # self.conn_mat = self.make_conn()
172 | #
173 | # def dist(self,d):
174 | # v_size = np.asarray([self.z_range,self.z_range])
175 | # return np.where(d > v_size/2, v_size -d ,d)
176 | #
177 | # def make_conn(self):
178 | # x1, x2 = np.meshgrid(self.x/self.z_max,self.x/self.z_max)
179 | # value = np.stack([x1.flatten(), x2.flatten()]).T
180 | #
181 | # dd = self.dist(np.abs(value[0] - value))
182 | # d = np.linalg.norm(dd, ord=1, axis=1)
183 | # d = d.reshape((self.length, self.length))
184 | # Jxx = self.J0 * np.exp(-0.5 * np.square(d/self.a))/(np.sqrt(2*np.pi)*self.a)
185 | # return Jxx
186 | #
187 | #
188 | # def data_reverse_transform(self, x):
189 | # x_min, x_max = np.min(x, axis=0),np.max(x, axis=0)
190 | # x_range = x_max - x_min
191 | # return (x - x_min)/(1e-11+x_range) * self.z_range + self.z_min
192 | #
193 | # def get_stimulus_by_pos(self, x):
194 | # # encode the external stimulus by Gaussian coding
195 | # x1, x2 = np.meshgrid(self.x, self.x)
196 | # value = np.stack([x1.flatten(), x2.flatten()]).T
197 | # # dd = self.dist(np.abs(np.asarray(x)-value))
198 | # dd = np.abs(np.asarray(x) - value)
199 | # d = np.linalg.norm(dd, axis=1)
200 | # d = d.reshape((self.length, self.length))
201 | # return self.A * np.exp(-0.25*np.square(d/2))
202 | #
203 | #
204 | # def update(self, data, trajactory_mode = False):
205 | # # convert the data input a 3D tensor: seq_len * dim * dim
206 | # seq_len, seq_dim = data.shape
207 | # u = np.zeros((self.length, self.length))
208 | # u_record = np.zeros((seq_len, self.length, self.length))
209 | #
210 | # for i in range(1, seq_len):
211 | # # generate the 2D inputs
212 | # cur_stimulus = self.get_stimulus_by_pos(data[i - 1])
213 | # t = np.arange(self.I_dur * i, self.I_dur * i + self.I_dur, self.dt)
214 | # # perform updates
215 | # cur_du = odeint(int_u, u.flatten(), t,
216 | # args=(self.conn_mat, cur_stimulus.T,
217 | # self.tau, self.Inh_inp, self.k, self.dt))
218 | # u = cur_du[-1].reshape(-1, self.length)
219 | # r1 = np.square(u)
220 | # r2 = 1.0 + self.k * np.sum(r1)
221 | # u_record[i] = r1/r2
222 | #
223 | #
224 | # # output the matrix and down-sampling
225 | # out = r1/r2
226 | # if trajactory_mode == True:
227 | # out = u_record
--------------------------------------------------------------------------------
/src/tools/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch
4 | import torch.utils.model_zoo
5 | from torch.utils.data import Dataset
6 | import torchvision
7 | import math
8 | import numpy as np
9 | import random
10 | from sklearn.metrics import precision_recall_curve
11 |
12 | def least_common_multiple(num):
13 | mini = 1
14 | for i in num:
15 | mini = int(i) * int(mini) / math.gcd(int(i), mini)  # running LCM: lcm(a, b) = a * b / gcd(a, b)
16 | mini = int(mini)
17 | return mini
18 |
19 | class SequentialDataset(Dataset):
20 | def __init__(self, data_dir, exp_idx, transform, nclass=None, seq_len_aps=None, seq_len_dvs = None,
21 | seq_len_gps=None, seq_len_head = None, seq_len_time = None):
22 | self.data_dir = data_dir
23 | self.transform = transform
24 | self.num_exp = len(exp_idx)
25 | self.exp_idx = exp_idx
26 |
27 | self.total_aps = [np.sort(os.listdir(data_dir + str(idx) + '/dvs_frames')) for idx in exp_idx]
28 | self.total_dvs = [np.sort(os.listdir(data_dir + str(idx) + '/dvs_7ms_3seq')) for idx in exp_idx]
29 | self.num_imgs = [len(x) for x in self.total_aps]
30 | self.raw_pos = [np.loadtxt(data_dir + str(idx) + '/position.txt', delimiter=' ') for idx in exp_idx]
31 | self.raw_head = [np.loadtxt(data_dir + str(idx) + '/direction.txt', delimiter=' ') for idx in exp_idx]
32 |
33 | self.t_pos = [x[:, 0] for x in self.raw_pos]
34 | self.t_aps = [[float(x[:-4]) for x in y] for y in self.total_aps]
35 | self.t_dvs = [[float(x[:-4]) for x in y] for y in self.total_dvs]
36 | self.data_pos = [idx[:, 0:3]-idx[:, 0:3].min(axis=0) for idx in self.raw_pos]
37 | self.data_head = [idx[:, 0:3] - idx[:, 0:3].min(axis=0) for idx in
38 | self.raw_head]
39 | self.seq_len_aps = seq_len_aps
40 | self.seq_len_gps = seq_len_gps
41 | self.seq_len_dvs = seq_len_dvs
42 | self.seq_len_head = seq_len_head
43 | self.seq_len_time = seq_len_time
44 | self.seq_len = max(seq_len_gps, seq_len_aps)
45 | self.nclass = nclass
46 |
47 | self.lens = len(self.total_aps) - self.seq_len
48 | self.dvs_data = None
49 | self.duration = [x[-1] - x[0] for x in self.t_dvs]
50 |
51 | nums = 1e5
52 | for x in self.total_aps:
53 | if len(x) < nums: nums = len(x)
54 | for x in self.total_dvs:
55 | if len(x) < nums: nums = len(x)
56 | for x in self.raw_pos:
57 | if len(x) < nums: nums = len(x)
58 |
59 | self.lens = nums
60 |
61 | def __len__(self):
62 | return self.lens - self.seq_len * 2
63 |
64 | def __getitem__(self, idx):
65 | exp_index = np.random.randint(self.num_exp)
66 | idx = max(min(idx, self.num_imgs[exp_index] - self.seq_len * 2), self.seq_len_dvs * 3)
67 | img_seq = []
68 | for i in range(self.seq_len_aps):
69 | img_loc = self.data_dir + str(self.exp_idx[exp_index]) + '/dvs_frames/' + \
70 | self.total_aps[exp_index][idx - self.seq_len_aps + i]
71 | img_seq += [Image.open(img_loc).convert('RGB')]
72 | img_seq_pt = []
73 |
74 | if self.transform:
75 | for images in img_seq:
76 | img_seq_pt += [torch.unsqueeze(self.transform(images), 0)]
77 |
78 | img_seq = torch.cat(img_seq_pt, dim=0)
79 | t_stamps = self.raw_pos[exp_index][:, 0]
80 | t_target = self.t_aps[exp_index][idx]
81 |
82 | idx_pos = max(np.searchsorted(t_stamps, t_target), self.seq_len_aps)
83 | pos_seq = self.data_pos[exp_index][idx_pos - self.seq_len_gps:idx_pos, :]
84 | pos_seq = torch.from_numpy(pos_seq.astype('float32'))
85 |
86 | t_stamps = self.raw_head[exp_index][:, 0]
87 | t_target = self.t_aps[exp_index][idx]
88 | idx_head = max(np.searchsorted(t_stamps, t_target), self.seq_len_aps)
89 | head_seq = self.data_head[exp_index][idx_head - self.seq_len_gps:idx_head, 1]  # window ends at idx_head so its length always matches the GPS window
90 | head_seq = torch.from_numpy(head_seq.astype('float32')).reshape(-1,1)
91 |
92 | idx_dvs = np.searchsorted(self.t_dvs[exp_index], t_target, sorter=None) - 1
93 | t_stamps = self.t_dvs[exp_index][idx_dvs]
94 | dvs_seq = torch.zeros(self.seq_len_dvs * 3, 2, 130, 173)
95 | for i in range(self.seq_len_dvs):
96 | dvs_path = self.data_dir + str(self.exp_idx[exp_index]) + '/dvs_7ms_3seq/' \
97 | + self.total_dvs[exp_index][idx_dvs - self.seq_len_dvs + i + 1]
98 | dvs_buf = torch.load(dvs_path)
99 | dvs_buf = dvs_buf.permute([1, 0, 2, 3])
100 | dvs_seq[i * 3: (i + 1) * 3] = torch.nn.functional.avg_pool2d(dvs_buf, 2)
101 | ids = int((t_stamps - self.t_dvs[exp_index][0]) / self.duration[exp_index] * self.nclass)
102 |
103 | ids = np.clip(ids, a_min=0, a_max= self.nclass - 1)
104 | ids = np.array(ids)
105 | ids = torch.from_numpy(ids).type(torch.long)
106 | return (img_seq, pos_seq, dvs_seq, head_seq), ids
107 |
108 | def Data(data_path=None, batch_size=None, exp_idx=None, is_shuffle=True, normalize = None, nclass=None, seq_len_aps=None,
109 | seq_len_dvs=None, seq_len_gps=None, seq_len_head = None, seq_len_time = None):
110 | dataset = SequentialDataset(data_dir=data_path,
111 | exp_idx=exp_idx,
112 | transform=torchvision.transforms.Compose([
113 | torchvision.transforms.Resize((240, 320)),
114 | torchvision.transforms.ToTensor(),
115 | normalize,
116 | ]),
117 | nclass=nclass,
118 | seq_len_aps=seq_len_aps,
119 | seq_len_dvs=seq_len_dvs,
120 | seq_len_gps=seq_len_gps,
121 | seq_len_head = seq_len_head,
122 | seq_len_time = seq_len_time)
123 |
124 | data_loader = torch.utils.data.DataLoader(dataset,
125 | shuffle=is_shuffle,
126 | drop_last=True,
127 | batch_size=batch_size,
128 | pin_memory=True,
129 | num_workers=8)
130 | return data_loader
131 |
132 | def Data_brightness(data_path=None, batch_size=None, exp_idx=None, is_shuffle=True, normalize = None, nclass=None, seq_len_aps=None,
133 | seq_len_dvs=None, seq_len_gps=None, seq_len_head = None, seq_len_time = None):
134 | dataset = SequentialDataset(data_dir=data_path,
135 | exp_idx=exp_idx,
136 | transform=torchvision.transforms.Compose([
137 | torchvision.transforms.Resize((240, 320)),
138 | torchvision.transforms.ToTensor(),
139 | normalize,
140 | torchvision.transforms.ColorJitter(brightness=0.5),
141 | ]),
142 | nclass=nclass,
143 | seq_len_aps=seq_len_aps,
144 | seq_len_dvs=seq_len_dvs,
145 | seq_len_gps=seq_len_gps,
146 | seq_len_head=seq_len_head,
147 | seq_len_time=seq_len_time)
148 |
149 | data_loader = torch.utils.data.DataLoader(dataset,
150 | shuffle=is_shuffle,
151 | drop_last=True,
152 | batch_size=batch_size,
153 | pin_memory=True,
154 | num_workers=8)
155 | return data_loader
156 |
157 | def Data_mask(data_path=None, batch_size=None, exp_idx=None, is_shuffle=True, normalize = None, nclass=None, seq_len_aps=None,
158 | seq_len_dvs=None, seq_len_gps=None, seq_len_head = None, seq_len_time = None):
159 | dataset = SequentialDataset(data_dir=data_path,
160 | exp_idx=exp_idx,
161 | transform=torchvision.transforms.Compose([
162 | torchvision.transforms.Resize((240, 320)),
163 | torchvision.transforms.ToTensor(),
164 | normalize,
165 | torchvision.transforms.RandomCrop(size=128, padding=128),
166 | ]),
167 | nclass=nclass,
168 | seq_len_aps=seq_len_aps,
169 | seq_len_dvs=seq_len_dvs,
170 | seq_len_gps=seq_len_gps,
171 | seq_len_head=seq_len_head,
172 | seq_len_time=seq_len_time)
173 |
174 | data_loader = torch.utils.data.DataLoader(dataset,
175 | shuffle=is_shuffle,
176 | drop_last=True,
177 | batch_size=batch_size,
178 | pin_memory=True,
179 | num_workers=8)
180 | return data_loader
181 |
182 | def setup_seed(seed):
183 | torch.manual_seed(seed)
184 | torch.cuda.manual_seed(seed)
185 | np.random.seed(seed)
186 | random.seed(seed)
187 | torch.backends.cudnn.deterministic = True
188 |
189 | def accuracy(output, target, topk=(1,)):
190 | with torch.no_grad():
191 | _, pred = output.topk(max(topk), dim=1, largest=True, sorted=True)
192 | pred = pred.t()
193 | correct = pred.eq(target.view(1, -1).expand_as(pred))
194 | res = [correct[:k].sum().item() for k in topk]
195 | return res
196 |
197 | def compute_matches(retrieved_all, ground_truth_info):
198 | matches=[]
199 | itr=0
200 | for retr in retrieved_all:
201 | if (retr == ground_truth_info[itr]):
202 | matches.append(1)
203 | else:
204 | matches.append(0)
205 | itr=itr+1
206 | return matches
207 |
208 | def compute_precision_recall(matches,scores_all):
209 | precision, recall, _ = precision_recall_curve(matches, scores_all)
210 | return precision, recall
211 |
--------------------------------------------------------------------------------
/src/main/main_mhnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.model_zoo
3 | import time
4 | from src.model.mhnn import *
5 | from src.tools.utils import *
6 | from src.config.config_utils import *
7 |
8 | import argparse
9 | import logging
10 | logging.basicConfig(level=logging.INFO)
11 | log = logging.getLogger()
12 |
13 | setup_seed(0)
14 |
15 | parser = argparse.ArgumentParser(description='mhnn', argument_default=argparse.SUPPRESS)
16 | parser.add_argument('--config_file', type=str, required=True, help='Configure file path')
17 |
18 | def main_mhnn(options: argparse.Namespace):
19 | config = get_config(options.config_file, logger=log)['main']
20 | config = update_config(config, options.__dict__)
21 |
22 | cnn_arch = str(config['cnn_arch'])
23 | num_epoch = int(config['num_epoch'])
24 | batch_size = int(config['batch_size'])
25 | num_class = int(config['num_class'])
26 | rnn_num = int(config['rnn_num'])
27 | cann_num = int(config['cann_num'])
28 | reservoir_num = int(config['reservoir_num'])
29 | num_iter = int(config['num_iter'])
30 | spiking_threshold = float(config['spiking_threshold'])
31 | sparse_lambdas = int(config['sparse_lambdas'])
32 | lr = float(config['lr'])
33 | r = float(config['r'])
34 |
35 | ann_pre_load = get_bool_from_config(config, 'ann_pre_load')
36 | snn_pre_load = get_bool_from_config(config, 'snn_pre_load')
37 | re_trained = get_bool_from_config(config, 're_trained')
38 |
39 | seq_len_aps = int(config['seq_len_aps'])
40 | seq_len_gps = int(config['seq_len_gps'])
41 | seq_len_dvs = int(config['seq_len_dvs'])
42 | seq_len_head = int(config['seq_len_head'])
43 | seq_len_time = int(config['seq_len_time'])
44 |
45 | dvs_expand = int(config['dvs_expand'])
46 | expand_len = least_common_multiple([seq_len_aps, seq_len_dvs * dvs_expand, seq_len_gps])
47 |
48 | test_exp_idx = []  # indices are parsed character by character, so only comma-separated single-digit indices (e.g. '5,7') are supported
49 | for idx in config['test_exp_idx']:
50 | if idx != ',':
51 | test_exp_idx.append(int(idx))
52 |
53 | train_exp_idx = []
54 | for idxt in config['train_exp_idx']:
55 | if idxt != ',':
56 | train_exp_idx.append(int(idxt))
57 |
58 | data_path = str(config['data_path'])
59 | snn_path = str(config['snn_path'])
60 | hnn_path = str(config['hnn_path'])
61 | model_saving_file_name = str(config['model_saving_file_name'])
62 |
63 | w_fps = int(config['w_fps'])
64 | w_gps = int(config['w_gps'])
65 | w_dvs = int(config['w_dvs'])
66 | w_head = int(config['w_head'])
67 | w_time = int(config['w_time'])
68 |
69 | device_id = str(config['device_id'])
70 |
71 | normalize = torchvision.transforms.Normalize(mean=[0.3537, 0.3537, 0.3537],
72 | std=[0.3466, 0.3466, 0.3466])
73 |
74 | train_loader = Data(data_path, batch_size=batch_size, exp_idx=train_exp_idx, is_shuffle=True,
75 | normalize=normalize, nclass=num_class,
76 | seq_len_aps=seq_len_aps, seq_len_dvs=seq_len_dvs, seq_len_gps=seq_len_gps,
77 | seq_len_head=seq_len_head, seq_len_time = seq_len_time)
78 |
79 | test_loader = Data(data_path, batch_size=batch_size, exp_idx=test_exp_idx, is_shuffle=True,
80 | normalize=normalize, nclass=num_class,
81 | seq_len_aps=seq_len_aps, seq_len_dvs=seq_len_dvs, seq_len_gps=seq_len_gps,
82 | seq_len_head=seq_len_head, seq_len_time = seq_len_time)
83 |
84 | mhnn = MHNN(device = device,
85 | cnn_arch = cnn_arch,
86 | num_epoch = num_epoch,
87 | batch_size = batch_size,
88 | num_class = num_class,
89 | rnn_num = rnn_num,
90 | cann_num = cann_num,
91 | reservoir_num = reservoir_num,
92 | spiking_threshold = spiking_threshold,
93 | sparse_lambdas = sparse_lambdas,
94 | r = r,
95 | lr = lr,
96 | w_fps = w_fps,
97 | w_gps = w_gps,
98 | w_dvs = w_dvs,
99 | w_head = w_head,
100 | w_time = w_time,
101 | seq_len_aps = seq_len_aps,
102 | seq_len_gps = seq_len_gps,
103 | seq_len_dvs = seq_len_dvs,
104 | seq_len_head = seq_len_head,
105 | seq_len_time = seq_len_time,
106 | dvs_expand = dvs_expand,
107 | expand_len = expand_len,
108 | train_exp_idx = train_exp_idx,
109 | test_exp_idx = test_exp_idx,
110 | data_path = data_path,
111 | snn_path = snn_path,
112 | hnn_path = hnn_path,
113 | num_iter = num_iter,
114 | ann_pre_load = ann_pre_load,
115 | snn_pre_load = snn_pre_load,
116 | re_trained = re_trained)
117 |
118 | mhnn.cann_init(np.concatenate((train_loader.dataset.data_pos[0],train_loader.dataset.data_head[0][:,1].reshape(-1,1)),axis=1))
119 |
120 | mhnn.to(device)
121 |
122 | optimizer = torch.optim.Adam(mhnn.parameters(), lr)
123 |
124 | lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
125 |
126 | criterion = nn.CrossEntropyLoss()
127 |
128 | record = {}
129 | record['loss'], record['top1'], record['top5'], record['top10'] = [], [], [], []
130 | best_test_acc1, best_test_acc5, best_recall, best_test_acc10 = 0., 0., 0, 0
131 |
132 | train_iters = iter(train_loader)
133 | iters = 0
134 | start_time = time.time()
135 | print(device)
136 |
137 | import torchmetrics
138 |
139 | best_recall = 0.
140 | test_acc = torchmetrics.Accuracy()
141 |
142 | test_recall = torchmetrics.Recall(average='none', num_classes=num_class)
143 | test_precision = torchmetrics.Precision(average='none', num_classes=num_class)
144 |
145 | for epoch in range(num_epoch):
146 |
147 | ###############
148 | ## for training
149 | ###############
150 |
151 | running_loss = 0.
152 | counts = 1.
153 | acc1_record, acc5_record, acc10_record = 0., 0., 0.
154 | while iters < num_iter:
155 | mhnn.train()
156 | optimizer.zero_grad()
157 |
158 | try:
159 | inputs, target = next(train_iters)
160 | except StopIteration:
161 | train_iters = iter(train_loader)
162 | inputs, target = next(train_iters)
163 |
164 | outputs, outs = mhnn(inputs, epoch=epoch)
165 |
166 | class_loss = criterion(outputs, target.to(device))
167 |
168 | sparse_loss = 0.
169 | w_module = (w_fps, w_dvs, w_gps, w_gps)  # only the first two weights are used: the model returns two neuron populations
170 | for (i, l) in enumerate(outs):
171 | sparse_loss += (l.mean() - r)**2 * w_module[i]
172 |
173 | loss = class_loss + sparse_lambdas * sparse_loss
174 |
175 | loss.backward()
176 |
177 | optimizer.step()
178 |
179 | running_loss += loss.cpu().item()
180 |
181 | acc1, acc5, acc10 = accuracy(outputs.cpu(), target, topk=(1, 5, 10))
182 | acc1, acc5, acc10 = acc1 / len(outputs), acc5 / len(outputs), acc10 / len(outputs)
183 |
184 | acc1_record += acc1
185 | acc5_record += acc5
186 | acc10_record += acc10
187 |
188 | counts += 1
189 | iters += 1.
190 | iters = 0
191 |
192 | record['loss'].append(loss.item())
193 | print('\n\nTime elapsed:', time.time() - start_time, 's')
194 | print(
195 | 'Training epoch %.1d, training loss :%.4f, sparse loss :%.4f, training Top1 acc: %.4f, training Top5 acc: %.4f' %
196 | (epoch, running_loss / (num_iter), sparse_lambdas * sparse_loss, acc1_record / counts, acc5_record / counts))
197 |
198 | lr_schedule.step()
199 | start_time = time.time()
200 |
201 | ##############
202 | ## for testing
203 | ##############
204 |
205 | running_loss = 0.
206 | mhnn.eval()
207 |
208 | with torch.no_grad():
209 |
210 | acc1_record, acc5_record, acc10_record = 0., 0., 0.
211 | counts = 1.
212 |
213 | for batch_idx, (inputs, target) in enumerate(test_loader):
214 |
215 | outputs, _ = mhnn(inputs, epoch=epoch)
216 |
217 | loss = criterion(outputs.cpu(), target)
218 |
219 | running_loss += loss.item()
220 | acc1, acc5, acc10 = accuracy(outputs.cpu(), target, topk=(1, 5, 10))
221 | acc1, acc5, acc10 = acc1 / len(outputs), acc5 / len(outputs), acc10 / len(outputs)
222 | acc1_record += acc1
223 | acc5_record += acc5
224 | acc10_record += acc10
225 |
226 | counts += 1
227 | outputs = outputs.cpu()
228 |
229 | test_acc(outputs.argmax(1), target)
230 | test_recall(outputs.argmax(1), target)
231 | test_precision(outputs.argmax(1), target)
232 |
233 | compute_num = 100
234 | if batch_idx % compute_num == 0 and batch_idx > 0:
235 | #print('time:', (time.time() - start_time) / compute_num)
236 | start_time = time.time()
237 |
238 | total_acc = test_acc.compute().mean()
239 | total_recall = test_recall.compute().mean()
240 | total_precision = test_precision.compute().mean()
241 |
242 | test_precision.reset()
243 | test_recall.reset()
244 | test_acc.reset()
245 |
246 | acc1_record = acc1_record / counts
247 | acc5_record = acc5_record / counts
248 | acc10_record = acc10_record / counts
249 |
250 | record['top1'].append(acc1_record)
251 | record['top5'].append(acc5_record)
252 | record['top10'].append(acc10_record)
253 |
254 |
255 | print('Current best Top1: ', best_test_acc1, 'Best Top5:', best_test_acc5)
256 |
257 | if epoch > 4:
258 | if best_test_acc1 < acc1_record:
259 | best_test_acc1 = acc1_record
260 | print('Achieving the best Top1, saving...', best_test_acc1)
261 |
262 | if best_test_acc5 < acc5_record:
263 | # best_test_acc1 = acc1_record
264 | best_test_acc5 = acc5_record
265 | print('Achieving the best Top5, saving...', best_test_acc5)
266 |
267 | if best_recall < total_recall:
268 | # best_test_acc1 = acc1_record
269 | best_recall = total_recall
270 | print('Achieving the best recall, saving...', best_recall)
271 |
272 | if best_test_acc10 < acc10_record:
273 | best_test_acc10 = acc10_record
274 |
275 | state = {
276 | 'net': mhnn.state_dict(),
277 | 'snn': mhnn.snn.state_dict(),
278 | 'record': record,
279 | 'best_recall': best_recall,
280 | 'best_acc1': best_test_acc1,
281 | 'best_acc5': best_test_acc5,
282 | 'best_acc10': best_test_acc10
283 | }
284 | if not os.path.isdir('../../checkpoint'):
285 | os.mkdir('../../checkpoint')
286 | torch.save(state, '../../checkpoint/' + model_saving_file_name + '.t7')
287 |
288 |
289 | if __name__ == '__main__':
290 | options, unknowns = parser.parse_known_args()
291 | main_mhnn(options)
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
--------------------------------------------------------------------------------
/src/model/mhnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.model_zoo
4 | import torchvision.models as models
5 | from src.model.snn import *
6 | from src.model.cann import CANN
7 | from src.tools.utils import *
8 |
9 | # gpu environment
10 | #os.environ['CUDA_VISIBLE_DEVICES'] = "3"
11 | device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
12 |
13 | class MHNN(nn.Module):
14 | def __init__(self, **kwargs):
15 | super(MHNN, self).__init__()
16 |
17 | self.cnn_arch = kwargs.get('cnn_arch')
18 | self.num_class = kwargs.get('num_class')
19 | self.cann_num = kwargs.get('cann_num')
20 | self.rnn_num = kwargs.get('rnn_num')
21 | self.lr = kwargs.get('lr')
22 | self.batch_size = kwargs.get('batch_size')
23 | self.sparse_lambdas = kwargs.get('sparse_lambdas')
24 | self.r = kwargs.get('r')
25 |
26 | self.reservoir_num = kwargs.get('reservoir_num')
27 | self.threshold = kwargs.get('spiking_threshold')
28 |
29 | self.num_epoch = kwargs.get('num_epoch')
30 | self.num_iter = kwargs.get('num_iter')
31 | self.w_fps = kwargs.get('w_fps')
32 | self.w_gps = kwargs.get('w_gps')
33 | self.w_dvs = kwargs.get('w_dvs')
34 | self.w_head = kwargs.get('w_head')
35 | self.w_time = kwargs.get('w_time')
36 |
37 | self.seq_len_aps = kwargs.get('seq_len_aps')
38 | self.seq_len_gps = kwargs.get('seq_len_gps')
39 | self.seq_len_dvs = kwargs.get('seq_len_dvs')
40 | self.seq_len_head = kwargs.get('seq_len_head')
41 | self.seq_len_time = kwargs.get('seq_len_time')
42 | self.dvs_expand = kwargs.get('dvs_expand')
43 |
44 | self.ann_pre_load = kwargs.get('ann_pre_load')
45 | self.snn_pre_load = kwargs.get('snn_pre_load')
46 | self.re_trained = kwargs.get('re_trained')
47 |
48 | self.train_exp_idx = kwargs.get('train_exp_idx')
49 | self.test_exp_idx = kwargs.get('test_exp_idx')
50 |
51 | self.data_path = kwargs.get('data_path')
52 | self.snn_path = kwargs.get('snn_path')
53 |
54 | #self.device = kwargs.get('device')
55 | self.device = device
56 |
57 | if self.ann_pre_load:
58 | print("=> Loading pre-trained model '{}'".format(self.cnn_arch))
59 | self.cnn = models.__dict__[self.cnn_arch](pretrained=self.ann_pre_load)
60 | else:
61 | print("=> Using randomly inizialized model '{}'".format(self.cnn_arch))
62 | self.cnn = models.__dict__[self.cnn_arch](pretrained=self.ann_pre_load)
63 |
64 | if self.cnn_arch == "mobilenet_v2":
65 | """ MobileNet """
66 | self.feature_dim = self.cnn.classifier[1].in_features
67 | self.cnn.classifier[1] = nn.Identity()
68 |
69 | elif self.cnn_arch == "resnet50":
70 | """ Resnet50 """
71 |
72 | self.feature_dim = 512
73 |
74 | # self.cnn.layer1 = nn.Identity()
75 | self.cnn.layer2 = nn.Identity()
76 | self.cnn.layer3 = nn.Identity()
77 | self.cnn.layer4 = nn.Identity()
78 | self.cnn.fc = nn.Identity()
79 |
80 | self.cnn.layer1[1] = nn.Identity()
81 | self.cnn.layer1[2] = nn.Identity()
82 |
83 | # self.cnn.layer2[0] = nn.Identity()
84 | # self.cnn.layer2[0].conv2 = nn.Identity()
85 | # self.cnn.layer2[0].bn2 = nn.Identity()
86 |
87 | fc_inputs = 256
88 | self.cnn.fc = nn.Linear(fc_inputs,self.feature_dim)
89 |
90 | else:
91 | print("=> Please check model name or configure architecture for feature extraction only, exiting...")
92 | exit()
93 |
94 | for param in self.cnn.parameters():
95 | param.requires_grad = self.re_trained
96 |
97 | #############
98 | # SNN module
99 | #############
100 | self.snn = SNN(device = self.device).to(self.device)
101 | self.snn_out_dim = self.snn.fc2.weight.size()[1]
102 | self.ann_out_dim = self.feature_dim
103 | self.cann_out_dim = 4 * self.cann_num
104 | self.reservior_inp_num = self.ann_out_dim + self.snn_out_dim + self.cann_out_dim
105 | self.LN = nn.LayerNorm(self.reservior_inp_num)
106 | if self.snn_pre_load:
107 | self.snn.load_state_dict(torch.load(self.snn_path)['snn'])
108 |
109 | #############
110 | # CANN module
111 | #############
112 | self.cann_num = self.cann_num
113 | self.cann = None
114 | self.num_class = self.num_class
115 |
116 | #############
117 | # MLSM module
118 | #############
119 | self.input_size = self.feature_dim
120 | self.reservoir_num = self.reservoir_num
121 |
122 | self.threshold = self.threshold if self.threshold is not None else 0.5  # keep the configured spiking_threshold; default to 0.5
123 | self.decay = nn.Parameter(torch.rand(self.reservoir_num))
124 |
125 | self.K = 128
126 | self.num_block = 5
127 | self.num_blockneuron = int(self.reservoir_num / self.num_block)
128 |
129 | self.decay_scale = 0.5
130 | self.beta_scale = 0.1
131 |
132 | self.thr_base1 = self.threshold
133 |
134 | self.thr_beta1 = nn.Parameter(self.beta_scale * torch.rand(self.reservoir_num))
135 |
136 | self.thr_decay1 = nn.Parameter(self.decay_scale * torch.rand(self.reservoir_num))
137 |
138 | self.ref_base1 = self.threshold
139 | self.ref_beta1 = nn.Parameter(self.beta_scale * torch.rand(self.reservoir_num))
140 | self.ref_decay1 = nn.Parameter(self.decay_scale * torch.rand(self.reservoir_num))
141 |
142 | self.cur_base1 = 0
143 | self.cur_beta1 = nn.Parameter(self.beta_scale * torch.rand(self.reservoir_num))
144 | self.cur_decay1 = nn.Parameter(self.decay_scale * torch.rand(self.reservoir_num))
145 |
146 | self.project = nn.Linear(self.reservior_inp_num, self.reservoir_num)
147 |
148 | self.project_mask_matrix = torch.zeros((self.reservior_inp_num, self.reservoir_num))
149 |
150 |
151 | input_node_list = [0, self.ann_out_dim, self.snn_out_dim, self.cann_num * 2, self.cann_num]
152 |
153 | input_cum_list = np.cumsum(input_node_list)
154 |
155 | for i in range(len(input_cum_list) - 1):
156 | self.project_mask_matrix[input_cum_list[i]:input_cum_list[i + 1],
157 | self.num_blockneuron * i:self.num_blockneuron * (i + 1)] = 1
158 |
159 | self.project.weight.data = self.project.weight.data * self.project_mask_matrix.t()
160 |
161 | self.lateral_conn = nn.Linear(self.reservoir_num, self.reservoir_num)
162 |
163 | # control the ratio of # of lateral conn.
164 | self.lateral_conn_mask = torch.rand(self.reservoir_num, self.reservoir_num) > 0.8
165 | # remove self-recurrent conn.
166 | self.lateral_conn_mask = self.lateral_conn_mask * (1 - torch.eye(self.reservoir_num, self.reservoir_num))
167 | # adjust the initial weight of conn.
168 | self.lateral_conn.weight.data = 1e-3 * self.lateral_conn.weight.data * self.lateral_conn_mask.T
169 |
170 | #############
171 | # readout module
172 | #############
173 |
174 | # self.mlp1 = nn.Linear(self.reservoir_num, 256)
175 | # self.mlp2 = nn.Linear(256, self.num_class)
176 | self.mlp = nn.Linear(self.reservoir_num, self.num_class)
177 |
178 | def cann_init(self, data):
179 | self.cann = CANN(data)
180 |
181 | def lr_initial_schedule(self, lrs=1e-3):
182 | hyper_param_list = ['decay',
183 | 'thr_beta1', 'thr_decay1',
184 | 'ref_beta1', 'ref_decay1',
185 | 'cur_beta1', 'cur_decay1']
186 | hyper_params = list(filter(lambda x: x[0] in hyper_param_list, self.named_parameters()))
187 | base_params = list(filter(lambda x: x[0] not in hyper_param_list, self.named_parameters()))
188 | hyper_params = [x[1] for x in hyper_params]
189 | base_params = [x[1] for x in base_params]
190 | optimizer = torch.optim.SGD(
191 | [
192 | {'params': base_params, 'lr': lrs},
193 | {'params': hyper_params, 'lr': lrs / 2},
194 | ], lr=lrs, momentum=0.9, weight_decay=1e-7
195 | )
196 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(self.num_epoch))
197 | return optimizer, scheduler
198 |
199 | def wta_mem_update(self, fc, fv, k, inputx, spike, mem, thr, ref, last_cur):
200 | # scaling_lateral_inputs = 0.1
201 | #state = fc(inputx) + fv(spike) * scaling_lateral_inputs
202 | state = fc(inputx) + fv(spike)
203 |
204 | mem = (mem - spike * ref) * self.decay + state + last_cur
205 | # ratio of # of winners
206 | q0 = (100 - 20) / 100
207 | mem = mem.reshape(self.batch_size, self.num_blockneuron, -1)
208 | nps = torch.quantile(mem, q=q0, keepdim=True, axis=1)
209 | # nps =mem.max(axis=1,keepdim=True)[0] - thr
210 |
211 | # allows that only winner can fire spikes
212 | mem = F.relu(mem - nps).reshape(self.batch_size, -1)
213 |
214 | # spiking firing function
215 | spike = act_fun(mem - thr)
216 | return spike.float(), mem
217 |
218 | def trace_update(self, trace, spike, decay):
219 | return trace * decay + spike
220 |
221 | def forward(self, inp, epoch=100):
222 | aps_inp = inp[0].to(self.device)
223 | gps_inp = inp[1]
224 | dvs_inp = inp[2].to(self.device)
225 | head_inp = inp[3].to(self.device)
226 |
227 | batch_size, seq_len, channel, w, h = aps_inp.size()
228 | if self.w_fps > 0:
229 | aps_inp = aps_inp.view(batch_size * seq_len, channel, w, h)
230 | aps_out = self.cnn(aps_inp)
231 | out1 = aps_out.reshape(batch_size, self.seq_len_aps, -1).permute([1, 0, 2])
232 | else:
233 | out1 = torch.zeros(self.seq_len_aps, batch_size, self.feature_dim, device=self.device).to(torch.float32)  # zero placeholder with the CNN feature width when the APS stream is off
234 |
235 | if self.w_dvs > 0:
236 | out2 = self.snn(dvs_inp, out_mode='time')
237 | else:
238 | out2 = torch.zeros(self.seq_len_dvs * 3, batch_size, self.snn_out_dim, device=self.device).to(torch.float32)
239 |
240 | ### CANN module
241 |
242 |
243 | if self.w_gps + self.w_head + self.w_time > 0:
244 | gps_record = []
245 | for idx in range(batch_size):
246 | buf = self.cann.update(torch.cat((gps_inp[idx],head_inp[idx].cpu()),axis=1), trajactory_mode=True)
247 | gps_record.append(buf[None, :, :, :])
248 | gps_out = torch.from_numpy(np.concatenate(gps_record)).to(self.device)
249 | gps_out = gps_out.permute([1, 0, 2, 3]).reshape(self.seq_len_gps, batch_size, -1)
250 | else:
251 | gps_out = torch.zeros((self.seq_len_gps, batch_size, self.cann_out_dim), device=self.device)
252 |
253 | # A generic CANN module was used for rapid testing; CANN1D/2D are provided in cann.py
254 | out3 = gps_out[:, :, self.cann_num:self.cann_num * 3].to(self.device).to(torch.float32) # position
255 | out4 = gps_out[:, :, : self.cann_num].to(self.device).to(torch.float32) # time
256 | out5 = gps_out[:, :, - self.cann_num:].to(self.device).to(torch.float32) # direction
257 |
258 | out3 *= self.w_gps
259 | out4 *= self.w_time
260 | out5 *= self.w_head
261 |
262 | expand_len = int(self.seq_len_gps * self.seq_len_aps / np.gcd(self.seq_len_gps, self.seq_len_dvs) * self.dvs_expand)
263 | input_num = self.feature_dim + self.snn.fc2.weight.size()[1] + self.cann_num * self.dvs_expand
264 |
265 | r_spike = r_mem = r_sumspike = torch.zeros(batch_size, self.reservoir_num, device=self.device)
266 |
267 | thr_trace = torch.zeros(batch_size, self.reservoir_num, device=self.device)
268 | ref_trace = torch.zeros(batch_size, self.reservoir_num, device=self.device)
269 | cur_trace = torch.zeros(batch_size, self.reservoir_num, device=self.device)
270 |
271 | K_winner = self.K * (1 + np.clip(self.num_epoch - epoch, a_min=0, a_max=self.num_epoch) / self.num_epoch)  # annealed winner budget; passed as k to wta_mem_update, which currently selects winners by quantile
272 |
273 | out1_zeros = torch.zeros_like(out1[0], device=self.device)
274 | out3_zeros = torch.zeros_like(out3[0], device=self.device)
275 | out4_zeros = torch.zeros_like(out4[0], device=self.device)
276 | out5_zeros = torch.zeros_like(out5[0], device=self.device)
277 |
278 | self.project.weight.data = self.project.weight.data * self.project_mask_matrix.t().to(self.device)
279 | self.lateral_conn.weight.data = self.lateral_conn.weight.data * self.lateral_conn_mask.T.to(self.device)
280 |
281 | self.decay.data = torch.clamp(self.decay.data, min=0, max=1.)
282 | self.thr_decay1.data = torch.clamp(self.thr_decay1.data, min=0, max=1.)
283 | self.ref_decay1.data = torch.clamp(self.ref_decay1.data, min=0, max=1)
284 | self.cur_decay1.data = torch.clamp(self.cur_decay1.data, min=0, max=1)
285 |
286 | for step in range(expand_len):
287 |
288 | idx = step % 3
289 | # simulate multimodal information with different time scales
290 | if idx == 2:
291 | combined_input = torch.cat((out1[step // 3], out2[step], out3[step // 3], out4[step // 3], out5[step // 3]), axis=1)
292 | else:
293 | combined_input = torch.cat((out1_zeros, out2[step],out3_zeros, out4_zeros, out5_zeros), axis=1)
294 |
295 | thr = self.thr_base1 + thr_trace * self.thr_beta1
296 | # ref = self.ref_base1 + ref_trace * self.ref_beta1 # option: ref = self.ref_base1
297 | ref = self.ref_base1
298 | cur = self.cur_base1 + cur_trace * self.cur_beta1
299 |
300 | inputx = combined_input.float()
301 | r_spike, r_mem = self.wta_mem_update(self.project, self.lateral_conn, K_winner, inputx, r_spike, r_mem,
302 | thr, ref, cur)
303 | thr_trace = self.trace_update(thr_trace, r_spike, self.thr_decay1)
304 | ref_trace = self.trace_update(ref_trace, r_spike, self.ref_decay1)
305 | cur_trace = self.trace_update(cur_trace, r_spike, self.cur_decay1)
306 | r_sumspike = r_sumspike + r_spike
307 |
308 | # cat_out = F.dropout(r_sumspike, p=0.5, training=self.training)
309 | # out1 = self.mlp1(r_sumspike).relu()
310 | # out2 = self.mlp2(out1)
311 | out2 = self.mlp(r_sumspike)
312 |
313 | neuron_pop = r_sumspike.reshape(batch_size, -1, self.num_blockneuron).permute([1, 0, 2])
314 | return out2, (neuron_pop[0], neuron_pop[1])
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
--------------------------------------------------------------------------------