├── image
│   ├── 1.txt
│   ├── SNN_net.png
│   ├── heatmap.png
│   ├── waveform.png
│   ├── workflow.png
│   └── Wavelet Transform.png
├── Wavelet transform example chart
│   ├── 1.txt
│   ├── line_1.png
│   ├── line_1017.png
│   └── line_99.png
├── code
│   ├── MTSA-SNN_dataloader.py
│   ├── model_trian_test.py
│   ├── SNN_Joint_learning.py
│   └── SNN_encoder.py
└── README.md

/image/1.txt:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/Wavelet transform example chart/1.txt:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/image/SNN_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/SNN_net.png
--------------------------------------------------------------------------------
/image/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/heatmap.png
--------------------------------------------------------------------------------
/image/waveform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/waveform.png
--------------------------------------------------------------------------------
/image/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/workflow.png
--------------------------------------------------------------------------------
/image/Wavelet Transform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/Wavelet Transform.png
--------------------------------------------------------------------------------
/Wavelet transform example chart/line_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/Wavelet transform example chart/line_1.png
--------------------------------------------------------------------------------
/Wavelet transform example chart/line_1017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/Wavelet transform example chart/line_1017.png
--------------------------------------------------------------------------------
/Wavelet transform example chart/line_99.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/Wavelet transform example chart/line_99.png
--------------------------------------------------------------------------------
/code/MTSA-SNN_dataloader.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

# Standard ImageNet normalization statistics. The validation pipeline is kept
# deterministic (Resize + CenterCrop) rather than randomly augmented.
data_transforms = {
    'train':
        transforms.Compose([
            transforms.RandomResizedCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    'valid':
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
}


class MultimodalDataset(Dataset):
    """Pairs each CSV row (features, with the label in the last column) with an
    image from an ImageFolder-style directory. Renamed from `Dataset` to avoid
    shadowing torch.utils.data.Dataset."""

    def __init__(self, data_file, image_file, transform=None):
        self.df = pd.read_csv(data_file)
        self.transform = transform
        self.image_file = image_file
        # ImageFolder yields (PIL image, class index) tuples; keep the images only.
        self.img = [i[0] for i in datasets.ImageFolder(self.image_file)]
        # The last CSV column is the label; the remaining columns are features.
        self.label = [row[-1] for row in self.df.values]
        self.data = [row[:-1] for row in self.df.values]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.img[idx]
        label = self.label[idx]
        data = self.data[idx]
        if self.transform:
            img = self.transform(img)  # PIL image -> normalized FloatTensor
        else:
            img = transforms.ToTensor()(img)

        # Stored as float16 to match the mixed-precision training loop.
        tensor_img = img.to(torch.float16).detach()
        tensor_fea = torch.tensor(np.asarray(data, dtype=np.float32), dtype=torch.float16).detach()
        tensor_lab = torch.tensor(label).detach()

        return tensor_img, tensor_fea, tensor_lab


dataset = MultimodalDataset("/dataset/...csv", 'dataset/image/', transform=data_transforms['train'])

# 80/20 train/test split.
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
--------------------------------------------------------------------------------
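`MultimodalDataset` pairs row *i* of the CSV with image *i* of the `ImageFolder` directory, so the two sources must be equally long and in matching order. A minimal sketch that fabricates compatible dummy data (every path, class name, and size below is an illustrative assumption, not part of the repository):

```python
# sketch_make_dummy_data.py -- hypothetical helper, not part of the repository.
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path

root = Path("dataset")
(root / "image" / "class_a").mkdir(parents=True, exist_ok=True)

# One 256x256 RGB image per CSV row, stored ImageFolder-style
# (one sub-directory per class).
for i in range(4):
    Image.fromarray(
        np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
    ).save(root / "image" / "class_a" / f"sample_{i}.png")

# 260 feature columns (matching num_inputs in SNN_encoder.py) plus a final
# integer label column, mirroring what MultimodalDataset expects.
features = np.random.randn(4, 260)
labels = np.random.randint(0, 5, size=(4, 1))
pd.DataFrame(np.hstack([features, labels])).to_csv(root / "data.csv", index=False)
```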
/code/model_trian_test.py:
--------------------------------------------------------------------------------
import time

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

Epoch = 100
batch_size = 4
num_steps = 20


def train_model(net, train_loader, optimizer, device):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_f1 = 0

    for x1, x2, label in tqdm(train_loader):
        optimizer.zero_grad()
        x1, x2, label = x1.to(device), x2.to(device), label.to(device)

        # Run the network for num_steps simulation steps; the loss is computed
        # on the output of the final step.
        with autocast():
            for step in range(num_steps):
                output = net(x1, x2)
                # Broadcast the single-sample output across the batch (the
                # joint-learning module is written for batch_size = 1).
                output = output.repeat(batch_size, 1)
            loss = F.cross_entropy(output.float(), label.long())

        # Backward pass outside the autocast region, as recommended for AMP.
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_loss += loss.item()
            output_argmax = output.argmax(1).cpu().numpy()
            label_np = label.cpu().numpy()

            train_acc += accuracy_score(label_np, output_argmax)
            train_f1 += f1_score(label_np, output_argmax, average='micro')

    train_time = time.time() - start_time
    train_loss /= len(train_loader)
    train_acc /= len(train_loader)
    train_f1 /= len(train_loader)

    return train_loss, train_acc, train_f1, train_time


def test_model(net, test_loader, device):
    start_time = time.time()
    net.eval()

    test_loss = 0
    test_acc = 0
    test_f1 = 0

    # Evaluation only: no gradients and no parameter updates.
    with torch.no_grad():
        for x1, x2, label in tqdm(test_loader):
            x1, x2, label = x1.to(device), x2.to(device), label.to(device)

            with autocast():
                output = net(x1, x2)
                output = output.repeat(batch_size, 1)
                loss = F.cross_entropy(output.float(), label.long())

            test_loss += loss.item()
            output_argmax = output.argmax(1).cpu().numpy()
            label_np = label.cpu().numpy()

            test_acc += accuracy_score(label_np, output_argmax)
            test_f1 += f1_score(label_np, output_argmax, average='micro')

    test_time = time.time() - start_time
    test_loss /= len(test_loader)
    test_acc /= len(test_loader)
    test_f1 /= len(test_loader)

    return test_loss, test_acc, test_f1, test_time
--------------------------------------------------------------------------------
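Neither file defines an entry point, so a small driver is needed to connect the loaders, the network, and the two loops. A minimal sketch, assuming the `Fusion` model from `SNN_Joint_learning.py` and an Adam optimizer (the learning rate is a guess, and `MTSA-SNN_dataloader.py` must first be renamed to an importable module name because of the hyphen):

```python
# sketch_driver.py -- illustrative glue code, not part of the repository.
import torch

from SNN_Joint_learning import Fusion
from model_trian_test import Epoch, train_model, test_model
# Assumes MTSA-SNN_dataloader.py was renamed, e.g. to mtsa_snn_dataloader.py;
# a hyphenated file name cannot be imported directly.
from mtsa_snn_dataloader import train_loader, test_loader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Fusion().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)  # illustrative choice

for epoch in range(Epoch):
    tr_loss, tr_acc, tr_f1, _ = train_model(net, train_loader, optimizer, device)
    te_loss, te_acc, te_f1, _ = test_model(net, test_loader, device)
    print(f"epoch {epoch:03d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} "
          f"| test loss {te_loss:.4f} acc {te_acc:.4f}")
```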
/README.md:
--------------------------------------------------------------------------------
# MTSA-SNN: A Multi-modal Time Series Analysis Model Based on Spiking Neural Network

Time series analysis and modelling constitute a crucial research area. However,
traditional artificial neural networks often encounter challenges when dealing
with complex, non-stationary time series data, such as high computational
complexity, a limited ability to capture temporal information, and difficulty in
handling event-driven data. To address these challenges, we propose a Multi-modal
Time Series Analysis Model Based on Spiking Neural Network (MTSA-SNN). The Pulse
Encoder unifies the encoding of temporal images and sequential information in a
common pulse-based representation. The Joint Learning Module employs a joint
learning function and a weight-allocation mechanism to fuse information from
multimodal pulse signals in a complementary manner. Additionally, we incorporate
wavelet transform operations to enhance the model's ability to analyze and
evaluate temporal information. Experimental results demonstrate that our method
achieves superior performance on three complex time-series tasks. This work
provides an effective event-driven approach to overcoming the challenges
associated with analyzing intricate temporal information.

## Requirements

- [PyTorch](https://pytorch.org/) >= 1.10.1
- [Python](https://www.python.org/) >= 3.7
- [Einops](https://github.com/arogozhnikov/einops) = 0.6.1
- [NumPy](https://numpy.org/) = 1.24.3
- [TorchVision](https://pytorch.org/vision/stable/transforms.html) = 0.9.1+cu111
- [scikit-learn](https://scikit-learn.org/stable/index.html) = 1.2.2
- [CUDA](https://developer.nvidia.com/cuda-toolkit) >= 11.3
- [snnTorch](https://github.com/jeshraghian/snntorch) and [SpikingJelly](https://github.com/fangwei123456/spikingjelly), as well as pandas and tqdm, are imported by the code but not pinned above

# MTSA_SNN Overall Model
![MTSA_SNN Overall Model](https://github.com/Chenngzz/MTSA-SNN/blob/main/image/SNN_net.png)

# Model Structure


# Wavelet Transform
MTSA-SNN employs the wavelet transform to decompose input signals into four
subbands: LL, LH, HL, and HH, which capture distinct signal characteristics at
different frequencies and spatial scales. Specifically, the LL subband contains
the low-frequency approximation of the signal; the LH and HL subbands capture
detail that is high-frequency along one axis and low-frequency along the other;
and the HH subband contains the components that are high-frequency in both axes
(the diagonal detail).
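For reference, the four-subband split described above can be reproduced with PyWavelets (not listed in the requirements; the `haar` wavelet and the input size are illustrative choices):

```python
# sketch_wavelet.py -- illustrative, requires the pywt (PyWavelets) package.
import numpy as np
import pywt

signal = np.random.randn(256, 256)  # stand-in for a temporal image

# One level of the 2-D discrete wavelet transform: pywt returns the
# approximation (LL) plus three detail subbands; the LH/HL/HH naming below
# follows the common convention for horizontal/vertical/diagonal details.
LL, (LH, HL, HH) = pywt.dwt2(signal, 'haar')

for name, band in [('LL', LL), ('LH', LH), ('HL', HL), ('HH', HH)]:
    print(name, band.shape)  # each subband is half-resolution: (128, 128)
```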

# The visualization of training and ablation experiments

# Training and evaluation
Run `code/model_trian_test.py` together with `code/MTSA-SNN_dataloader.py` (a driver sketch follows `model_trian_test.py` above).


## Citation
```
@misc{liu2024mtsasnnmultimodaltimeseries,
      title={MTSA-SNN: A Multi-modal Time Series Analysis Model Based on Spiking Neural Network},
      author={Chengzhi Liu and Zheng Tao and Zihong Luo and Chenghao Liu},
      year={2024},
      eprint={2402.05423},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2402.05423},
}
```
--------------------------------------------------------------------------------
/code/SNN_Joint_learning.py:
--------------------------------------------------------------------------------
import torch
import snntorch as snn
from torch import einsum, nn
from snntorch import surrogate
import torch.nn.functional as F
from einops import rearrange, repeat
from SNN_encoder import SNN_img, SNN_series, SNN_pulse_supplementation

batch_size = 1
num_inputs = 260
num_inputs_img = 256

num_hidden = 5
num_outputs = 10
num_steps = batch_size
beta = 0.95
spike_grad = surrogate.fast_sigmoid(slope=25)


class LayerNorm(nn.Module):
    """Layer norm whose affine parameters are sliced to the input's last
    dimension, so one module can serve inputs of different widths."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        gamma = self.gamma[:x.shape[-1]]
        beta = self.beta[:x.shape[-1]]
        return F.layer_norm(x, x.shape[1:], gamma.repeat(x.shape[1], 1), beta.repeat(x.shape[1], 1))


class SNN_joint_learning_module(nn.Module):
    def __init__(self, dim, dim_out, heads, dim_head):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = LayerNorm(dim)

        inner_dim = heads * dim_head

        # One leaky integrate-and-fire neuron population per modality.
        self.snn_text = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.snn_img = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.snn_out = snn.Leaky(beta=beta, spike_grad=spike_grad)

        self.identity_pic = nn.Linear(dim, inner_dim)
        self.identity_context = nn.Linear(dim, dim_head * 2)

        self.resnet = nn.Sequential(nn.Conv1d(512, 32, 1, 1),
                                    nn.Conv1d(32, batch_size, 1, 1))

        self.relu = nn.ReLU()
        self.fc = nn.Linear(512, 512)

    def forward(self, pic, context):
        image = []
        text = []

        mem1 = self.snn_img.init_leaky()
        mem2 = self.snn_text.init_leaky()

        pic = self.norm(pic)
        context = self.norm(context)

        # Accumulate image spikes over the simulation steps.
        for step in range(num_steps):
            identity_pic, mem1 = self.snn_img(pic, mem1)
            image.append(identity_pic)
        identity_pic = torch.stack(image, dim=0)

        identity_pic = self.identity_pic(identity_pic)
        identity_pic = identity_pic * self.scale

        # Accumulate series spikes over the simulation steps.
        for step in range(num_steps):
            identity_context, mem2 = self.snn_text(context, mem2)
            text.append(identity_context)
        identity_context = torch.stack(text, dim=0)

        # Split the projected context into a key-like and a value-like half.
        identity_context, Value = self.identity_context(identity_context).chunk(2, dim=-1)

        if identity_context.ndim == 2 or Value.ndim == 2:
            identity_context = repeat(identity_context, 'h w -> h w c', c=64)
            Value = repeat(Value, 'h w -> h w c', c=64)

        # The context is summed over its last axis and gates the image
        # features element-wise (unlike standard dot-product attention).
        sim = einsum('b h i d, b h i j -> b h i d', identity_pic, identity_context)
        sim = sim - sim.amax(dim=-1, keepdim=True)
        attn = sim.softmax(dim=-1)
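        # 'b h i j, b h i d -> b h i j' sums Value over its last axis, so each
        # attention entry is scaled by the total value mass (no index is
        # matched between j and d).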
        out = einsum('b h i j, b h i d -> b h i j', attn, Value)
        out = rearrange(out, 'b h n d -> b n (h d)')

        # Residual paths back to the raw image and context features.
        identity_pic = torch.sum(identity_pic, dim=3)
        out_1 = self.resnet(out)
        out_1 = self.fc(out_1)
        out_1 = out_1 + identity_pic

        out_1 = torch.transpose(out_1, 1, 2)
        out_1 = F.interpolate(out_1, size=(64,), mode='linear', align_corners=False)

        identity_context = torch.sum(identity_context, dim=1)
        out_2 = self.relu(out_1) + identity_context
        out_2 = self.relu(out_2)

        return out_2


class Fusion(nn.Module):
    def __init__(self):
        super().__init__()

        self.img_net = SNN_img()
        self.img_enhance = SNN_pulse_supplementation()
        self.text_net = SNN_series()
        self.output_layer = nn.Sequential(nn.Flatten(),
                                          nn.Linear(32768, 1024),
                                          nn.Linear(1024, 5),
                                          nn.ReLU())

        self.fusion_1 = SNN_joint_learning_module(dim=512, dim_out=512, heads=8, dim_head=64)

        self.linear_img = nn.Sequential(nn.Linear(4096, 512), nn.ReLU())
        self.linear_series = nn.Sequential(nn.Linear(10, 512), nn.ReLU())
        self.linear_enhance = nn.Sequential(nn.Linear(512, 5), nn.ReLU())

    def forward(self, img, series):
        # Series branch: spike-encode, project to 512-d, and broadcast.
        series_embed = self.text_net(series)
        series_embed = series_embed.reshape(batch_size, -1)
        series_embed = self.linear_series(series_embed)
        series_embed = rearrange(series_embed, 'b n -> b n 1')
        series_embed = series_embed.repeat(1, 1, 512)

        # Image branch: spike-encode and project.
        img_embed = self.img_net(img)
        img_embed = img_embed.reshape(batch_size, 512, -1)
        img_embed = self.linear_img(img_embed)

        # Pulse-supplementation branch adds a direct image shortcut.
        img_enhance = self.img_enhance(img)
        img_enhance = self.linear_enhance(img_enhance)

        out = self.fusion_1(img_embed, series_embed)
        out = self.output_layer(out)

        out = out + img_enhance

        return out


if __name__ == "__main__":
    # Quick smoke test with random inputs.
    net_snn = Fusion()
    img = torch.randn(batch_size, 3, 256, 256)
    x = torch.randn(batch_size, 260)

    output = net_snn(img, x)
    print(output.shape)
--------------------------------------------------------------------------------
/code/SNN_encoder.py:
--------------------------------------------------------------------------------
import torch
import snntorch as snn
from torch import nn
from snntorch import surrogate
import torch.nn.functional as F
from einops import rearrange
from spikingjelly.clock_driven.neuron import MultiStepLIFNode

batch_size = 1
num_inputs = 260
num_inputs_img = 256
num_hidden = 5
num_outputs = 10
num_steps = batch_size
beta = 0.95
spike_grad = surrogate.fast_sigmoid(slope=25)


# SNN series encoder: two fully connected layers with leaky
# integrate-and-fire neurons, run for num_steps simulation steps.
class SNN_series(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
        output = torch.stack(spk2_rec, dim=0)

        return output


# SNN image encoder: convolutional spiking layers plus a non-spiking
# convolutional shortcut on the raw input.
class SNN_img(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(3, 12, 1)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 32, 1)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        self.fc1 = nn.Linear(2097152, 512)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        self.conv3 = nn.Conv2d(3, 32, 1)

    def forward(self, x):
        img = []
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        for step in range(num_steps):
            cur1 = self.conv1(x)
            cur1 = F.max_pool2d(cur1, 1)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = F.max_pool2d(self.conv2(spk1), 1)
            out_resnet, mem2 = self.lif2(cur2, mem2)

            img.append(out_resnet)
        output = torch.stack(img, dim=0)

        # Non-spiking shortcut, broadcast over the time dimension.
        output_conv = self.conv3(x)
        first_img = rearrange(output_conv, 'b n j d -> 1 b n j d')

        output = first_img + output

        return output


# SNN transformer: spiking self-attention, with 1x1 convolutions producing
# Q, K and V and multi-step LIF neurons between the stages.
class Pulse_extraction(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        self.scale = 0.125
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True)

        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True)

        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True)
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True)

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True)

    def forward(self, x, res_attn):
        T, B, C, H, W = x.shape
        x = x.flatten(3)
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()
        q = self.q_lif(q_conv_out)

        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k = self.k_lif(k_conv_out)

        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v = self.v_lif(v_conv_out)

        # Token-token similarity (k^T v), which then reweights q.
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale
        x = self.attn_lif(x)
        x = x.flatten(0, 1)
        # Project back and restore the (T, B, C, H, W) layout; the original
        # applied proj_lif twice here, which was redundant.
        x = self.proj_lif(self.proj_bn(self.proj_conv(x)).reshape(T, B, C, H, W))
        return x, v
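# Note on the attention above: with the (channel, token) layout used here,
# k.transpose(-2, -1) @ v builds a token-by-token similarity map that then
# reweights q. No softmax is applied, which is common in spiking attention
# designs where Q, K and V are non-negative spike tensors.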
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., sr_ratio=1):
        super().__init__()

        self.attn = Pulse_extraction(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                     attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)

    def forward(self, x, res_attn):
        # Residual connection around the spiking attention.
        x_attn, attn = self.attn(x, res_attn)
        x = x + x_attn
        return x, attn


class Pulse_transformation(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(num_inputs_img, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
        output = torch.stack(spk2_rec, dim=0)
        return output


class SNN_pulse_supplementation(nn.Module):
    def __init__(self, embed_dims=3, num_heads=4, mlp_ratios=4, qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, sr_ratios=[8, 4, 2]):
        super().__init__()

        self.patch_embed = Pulse_transformation()

        self.block = nn.ModuleList([Block(
            dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, sr_ratio=sr_ratios)])

        self.linear = nn.Linear(3, 512)

    def forward_features(self, x):
        output = self.patch_embed(x)
        attn = None
        for blk in self.block:
            x, attn = blk(output, attn)
        # Average over the flattened spatial positions.
        return x.flatten(3).mean(3)

    def forward(self, x):
        x = self.forward_features(x)
        final_x = x[0, :, :]  # take the first simulation step
        final_x = self.linear(final_x)
        return final_x
--------------------------------------------------------------------------------
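For readers unfamiliar with the `snn.Leaky` neurons used throughout the encoders, here is a minimal standalone sketch of the leaky integrate-and-fire update they implement (snnTorch additionally provides surrogate gradients and configurable reset modes):

```python
# sketch_lif.py -- illustrative leaky integrate-and-fire dynamics.
import torch

def lif_step(current, mem, beta=0.95, threshold=1.0):
    """One simulation step: decay the membrane potential, add the input
    current, emit a spike when the threshold is crossed, then reset the
    membrane by subtraction."""
    mem = beta * mem + current
    spike = (mem > threshold).float()
    mem = mem - spike * threshold
    return spike, mem

mem = torch.zeros(10)
for step in range(20):
    current = torch.rand(10)             # stand-in for a fc/conv output
    spike, mem = lif_step(current, mem)  # binary spikes drive the next layer
```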