├── image
│   ├── 1.txt
│   ├── SNN_net.png
│   ├── heatmap.png
│   ├── waveform.png
│   ├── workflow.png
│   └── Wavelet Transform.png
├── Wavelet transform example chart
│   ├── 1.txt
│   ├── line_1.png
│   ├── line_1017.png
│   └── line_99.png
├── code
│   ├── MTSA-SNN_dataloader.py
│   ├── model_trian_test.py
│   ├── SNN_Joint_learning.py
│   └── SNN_encoder.py
└── README.md
/image/1.txt:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Wavelet transform example chart/1.txt:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/image/SNN_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/SNN_net.png
--------------------------------------------------------------------------------
/image/heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/heatmap.png
--------------------------------------------------------------------------------
/image/waveform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/waveform.png
--------------------------------------------------------------------------------
/image/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/workflow.png
--------------------------------------------------------------------------------
/image/Wavelet Transform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/image/Wavelet%20Transform.png
--------------------------------------------------------------------------------
/Wavelet transform example chart/line_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/Wavelet%20transform%20example%20chart/line_1.png
--------------------------------------------------------------------------------
/Wavelet transform example chart/line_1017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/Wavelet%20transform%20example%20chart/line_1017.png
--------------------------------------------------------------------------------
/Wavelet transform example chart/line_99.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenngzz/MTSA-SNN/HEAD/Wavelet%20transform%20example%20chart/line_99.png
--------------------------------------------------------------------------------
/code/MTSA-SNN_dataloader.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

# ImageNet-style augmentation and normalization, applied to both splits.
data_transforms = {
    'train':
        transforms.Compose([
            transforms.RandomResizedCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    'valid':
        transforms.Compose([
            transforms.RandomResizedCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
}


class MultimodalDataset(Dataset):
    """Pairs each image with the matching CSV row of sequence features."""

    def __init__(self, data_file, image_file, transform=None):
        self.df = pd.read_csv(data_file)
        self.transform = transform
        self.image_file = image_file
        # ImageFolder yields (PIL image, class) pairs; keep only the images.
        self.img = [i[0] for i in datasets.ImageFolder(self.image_file)]
        # The last CSV column is the label; the rest are sequence features.
        self.label = [row[-1] for row in self.df.values]
        self.data = [row[:-1] for row in self.df.values]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.img[idx]
        label = self.label[idx]
        data = self.data[idx]
        # Fall back to a bare tensor conversion when no transform is given,
        # so `image` is always defined.
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        # float16 keeps memory low and matches the autocast training loop.
        tensor_fea = torch.tensor(np.asarray(data), dtype=torch.float16).detach()
        tensor_lab = torch.tensor(np.asarray(label)).detach()
        tensor_img = image.to(torch.float16).detach()

        return tensor_img, tensor_fea, tensor_lab


dataset = MultimodalDataset("/dataset/...csv", 'dataset/image/',
                            transform=data_transforms['train'])

# 80/20 train/test split.
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
--------------------------------------------------------------------------------
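A quick sanity check, assuming the CSV path and image directory above point at real data; it pulls one batch and prints the three tensor shapes (image, sequence features, label):

```python
# Sketch: fetch a single batch to confirm the (image, features, label) layout.
imgs, feats, labels = next(iter(train_loader))
print(imgs.shape, feats.shape, labels.shape)
# e.g. torch.Size([32, 3, 256, 256]) torch.Size([32, n_features]) torch.Size([32])
```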
/code/model_trian_test.py:
--------------------------------------------------------------------------------
import math
import time

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

torch.pi = math.pi  # shim: older PyTorch releases do not define torch.pi

Epoch = 100
batch_size = 4
num_steps = 20  # SNN simulation time steps per batch


def train_model(net, train_loader, optimizer, device):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_f1 = 0

    for x1, x2, label in tqdm(train_loader):
        optimizer.zero_grad()
        x1, x2, label = x1.to(device), x2.to(device), label.to(device)

        with autocast():
            # Run the spiking network for num_steps time steps and sum the
            # per-step losses, then backpropagate once per batch.
            loss = 0.0
            for step in range(num_steps):
                output = net(x1, x2)
                # Replicate the logit row across the actual batch size
                # (label.size(0) also handles a smaller final batch).
                output = output.repeat(label.size(0), 1)
                loss = loss + F.cross_entropy(output.float(), label.long())

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_loss += loss.item() / num_steps
            output_argmax = output.argmax(1).cpu().numpy()
            label_np = label.cpu().numpy()
            train_acc += accuracy_score(label_np, output_argmax)
            train_f1 += f1_score(label_np, output_argmax, average='micro')

    train_time = time.time() - start_time
    train_loss /= len(train_loader)
    train_acc /= len(train_loader)
    train_f1 /= len(train_loader)

    return train_loss, train_acc, train_f1, train_time


def test_model(net, test_loader, device):
    start_time = time.time()
    net.eval()

    test_loss = 0
    test_acc = 0
    test_f1 = 0

    # Evaluation only: no gradient tracking, no optimizer updates.
    with torch.no_grad():
        for x1, x2, label in tqdm(test_loader):
            x1, x2, label = x1.to(device), x2.to(device), label.to(device)

            with autocast():
                output = net(x1, x2)
                output = output.repeat(label.size(0), 1)
                loss = F.cross_entropy(output.float(), label.long())

            test_loss += loss.item()
            output_argmax = output.argmax(1).cpu().numpy()
            label_np = label.cpu().numpy()
            test_acc += accuracy_score(label_np, output_argmax)
            test_f1 += f1_score(label_np, output_argmax, average='micro')

    test_time = time.time() - start_time
    test_loss /= len(test_loader)
    test_acc /= len(test_loader)
    test_f1 /= len(test_loader)

    return test_loss, test_acc, test_f1, test_time
--------------------------------------------------------------------------------
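For context, a hypothetical driver sketching how these two functions can be wired together. `build_model()` is a stand-in for whatever constructor SNN_Joint_learning.py actually exposes, and `train_loader`/`test_loader` are assumed to come from the dataloader script above (its hyphenated filename would need renaming before it could be imported):

```python
# Hypothetical usage sketch -- build_model() is a placeholder, not the repo's API.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = build_model().to(device)                            # assumed constructor
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)   # learning rate is an assumption

for epoch in range(Epoch):
    tr_loss, tr_acc, tr_f1, _ = train_model(net, train_loader, optimizer, device)
    te_loss, te_acc, te_f1, _ = test_model(net, test_loader, device)
    print(f"epoch {epoch:03d} | train loss {tr_loss:.4f} acc {tr_acc:.4f} "
          f"| test loss {te_loss:.4f} acc {te_acc:.4f}")
```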
/README.md:
--------------------------------------------------------------------------------
# MTSA-SNN: A Multi-modal Time Series Analysis Model Based on Spiking Neural Network

Time series analysis and modelling constitute a crucial research area. However, traditional artificial neural networks often encounter challenges when dealing with complex, non-stationary time series data, such as high computational complexity, limited ability to capture temporal information, and difficulty in handling event-driven data. To address these challenges, we propose a Multi-modal Time Series Analysis Model Based on Spiking Neural Network (MTSA-SNN). The Pulse Encoder unifies the encoding of temporal images and sequential information in a common pulse-based representation. The Joint Learning Module employs a joint learning function and a weight allocation mechanism to complementarily fuse information from multi-modal pulse signals. Additionally, we incorporate wavelet transform operations to enhance the model's ability to analyze and evaluate temporal information. Experimental results demonstrate that our method achieved superior performance on three complex time-series tasks. This work provides an effective event-driven approach to overcoming the challenges associated with analyzing intricate temporal information.
## Requirements

- [PyTorch](https://pytorch.org/) >= 1.10.1
- [Python](https://www.python.org/) >= 3.7
- [Einops](https://github.com/arogozhnikov/einops) == 0.6.1
- [NumPy](https://numpy.org/) == 1.24.3
- [TorchVision](https://pytorch.org/vision/stable/transforms.html) == 0.9.1+cu111
- [scikit-learn](https://scikit-learn.org/stable/index.html) == 1.2.2
- [CUDA](https://developer.nvidia.com/cuda-toolkit) >= 11.3

# MTSA_SNN Overall Model
![MTSA-SNN overall workflow](image/workflow.png)

# Model Structure
![MTSA-SNN network structure](image/SNN_net.png)

# Wavelet Transform
MTSA-SNN employs the wavelet transform to decompose input signals into four subbands, LL, LH, HL, and HH, which capture distinct signal characteristics at different frequencies and spatial scales. The LL subband holds the low-frequency approximation of the signal; the LH and HL subbands combine low-pass filtering along one axis with high-pass filtering along the other, capturing directional detail; and the HH subband holds the high-frequency (diagonal) detail. A minimal code sketch of the decomposition follows below.

![Wavelet transform](image/Wavelet%20Transform.png)

Example wavelet-transform charts for individual series are collected in `Wavelet transform example chart/`.

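A minimal sketch of the one-level decomposition using PyWavelets (`pywt` is not among the listed requirements; it is used here only to illustrate the four-subband split):

```python
# One-level 2D discrete wavelet transform: an approximation band (LL) plus
# three detail bands (LH, HL, HH), each at half the input resolution.
import numpy as np
import pywt

signal = np.random.rand(256, 256)              # stand-in for an input image
LL, (LH, HL, HH) = pywt.dwt2(signal, 'haar')
print(LL.shape, LH.shape, HL.shape, HH.shape)  # four (128, 128) subbands
```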
# Visualization of Training and Ablation Experiments
![Training heatmap](image/heatmap.png)
![Output waveform](image/waveform.png)