├── Gated Transformer 论文IJCAI版
├── .idea
│ ├── .gitignore
│ ├── ADL_Transformer.iml
│ ├── inspectionProfiles
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ └── vcs.xml
├── Gated Transformer 架构图.png
├── dataset_process
│ ├── __pycache__
│ │ └── dataset_process.cpython-37.pyc
│ └── dataset_process.py
├── font
│ └── simsun.ttc
├── gather_figure
│ ├── AUSLAN
│ │ ├── AUSLAN channel_gather.jpg
│ │ └── AUSLAN input_gather.jpg
│ ├── ECG
│ │ ├── ECG all_sample_gather.jpg
│ │ ├── ECG channel_gather.jpg
│ │ └── ECG input_gather.jpg
│ └── JapaneseVowels
│ │ ├── JapaneseVowels Sample_gather.jpg
│ │ ├── JapaneseVowels channel_gather.jpg
│ │ └── JapaneseVowels input_gather.jpg
├── heatmap_figure_in_test
│ └── JapaneseVowels
│ │ └── JapaneseVowels accuracy=98.11 0.jpg
├── module
│ ├── __pycache__
│ │ ├── encoder.cpython-37.pyc
│ │ ├── feedForward.cpython-37.pyc
│ │ ├── loss.cpython-37.pyc
│ │ ├── multiHeadAttention.cpython-37.pyc
│ │ └── transformer.cpython-37.pyc
│ ├── encoder.py
│ ├── feedForward.py
│ ├── loss.py
│ ├── multiHeadAttention.py
│ └── transformer.py
├── mytest
│ ├── DTW_test.py
│ ├── HeatMap.py
│ ├── HeatMap_DTW.py
│ ├── __pycache__
│ │ └── HeatMap_DTW.cpython-37.pyc
│ └── kmeans_test.py
├── result_figure
│ ├── CharacterTrajectories 96.6% Adam epoch=100 batch=4 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
│ ├── ECG 86.0% Adam epoch=10 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
│ ├── ECG 86.0% Adam epoch=20 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
│ ├── ECG 87.0% Adam epoch=50 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
│ └── JapaneseVowels 98.11% Adagrad epoch=100 batch=3 lr=0.0001 pe=True mask=True [512,8,8,8,8,0.2].png
├── run.py
├── run_with_saved_model.py
└── utils
│ ├── TSNE.py
│ ├── __pycache__
│ ├── TSNE.cpython-37.pyc
│ ├── colorful_line.cpython-37.pyc
│ ├── draw_line.cpython-37.pyc
│ ├── heatMap.cpython-37.pyc
│ ├── random_seed.cpython-37.pyc
│ └── visualization.cpython-37.pyc
│ ├── colorful_line.py
│ ├── draw_line.py
│ ├── heatMap.py
│ ├── kmeans.py
│ ├── random_seed.py
│ └── visualization.py
└── README.md
/Gated Transformer 论文IJCAI版/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/.idea/ADL_Transformer.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/Gated Transformer 架构图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/Gated Transformer 架构图.png
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/dataset_process/__pycache__/dataset_process.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/dataset_process/__pycache__/dataset_process.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/dataset_process/dataset_process.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 |
4 | from scipy.io import loadmat
5 |
6 |
class MyDataset(Dataset):
    """Train/test ``Dataset`` backed by one multivariate time-series ``.mat`` file.

    Both splits are loaded and zero-padded to a common number of time steps
    when the object is built; ``dataset`` only selects which split is served
    by ``__getitem__`` and ``__len__``.
    """

    _SPLITS = ('train', 'test')
    # Field names that must exist inside the MATLAB struct.
    _FIELDS = ('train', 'trainlabels', 'test', 'testlabels')

    def __init__(self,
                 path: str,
                 dataset: str):
        """
        :param path: path of the ``.mat`` file to load
        :param dataset: ``'train'`` or ``'test'``, the split this object serves
        """
        super(MyDataset, self).__init__()
        self.dataset = dataset
        (self.train_len,
         self.test_len,
         self.input_len,
         self.channel_len,
         self.output_len,
         self.train_dataset,
         self.train_label,
         self.test_dataset,
         self.test_label,
         self.max_length_sample_inTest,
         self.train_dataset_with_no_paddding) = self.pre_option(path)

    def _require_known_split(self):
        """Raise instead of silently returning ``None`` for an unknown split."""
        if self.dataset not in self._SPLITS:
            raise ValueError(
                "dataset must be 'train' or 'test', got {!r}".format(self.dataset))

    def __getitem__(self, index):
        """Return ``(sample, label - 1)`` of the selected split."""
        self._require_known_split()
        if self.dataset == 'train':
            return self.train_dataset[index], self.train_label[index] - 1
        return self.test_dataset[index], self.test_label[index] - 1

    def __len__(self):
        """Return the number of samples in the selected split."""
        self._require_known_split()
        if self.dataset == 'train':
            return self.train_len
        return self.test_len

    @staticmethod
    def _pad_time_axis(sample, length: int):
        """Return ``sample`` as a float tensor, zero-padded along dim 1 to ``length``."""
        tensor = torch.as_tensor(sample).float()
        missing = length - tensor.shape[1]
        if missing > 0:
            padding = torch.zeros(tensor.shape[0], missing)
            tensor = torch.cat((tensor, padding), dim=1)
        return tensor

    def pre_option(self, path: str):
        """
        Load and pre-process the data. Samples have different numbers of time
        steps, so every sample is zero-padded to the longest one found in
        either split.

        :param path: path of the ``.mat`` file
        :return: number of training samples, number of test samples,
                 number of time steps, number of channels, number of classes,
                 training data, training labels, test data, test labels,
                 list of the test samples that already have the maximum
                 number of time steps, un-padded training data (nested lists)
        :raises ValueError: if the file holds no data variable
        :raises KeyError: if the struct lacks one of the expected fields
        """
        mat = loadmat(path)

        # loadmat adds '__header__', '__version__' and '__globals__'; the data
        # variable is the remaining (last) key.
        data_keys = [key for key in mat if not key.startswith('__')]
        if not data_keys:
            raise ValueError('no data variable found in {!r}'.format(path))
        data = mat[data_keys[-1]]

        # data[0][0] is the single struct record that holds the four fields.
        record = data[0][0]

        # Look the fields up by name instead of parsing str(dtype), so a
        # missing field fails loudly rather than selecting a wrong column.
        names = data.dtype.names or ()
        missing = [field for field in self._FIELDS if field not in names]
        if missing:
            raise KeyError('missing field(s) {} in {!r}'.format(missing, path))

        train_label = record['trainlabels'].squeeze()
        train_data = record['train'].squeeze()
        test_label = record['testlabels'].squeeze()
        test_data = record['test'].squeeze()

        train_len = train_data.shape[0]
        test_len = test_data.shape[0]
        output_len = len(set(train_label))

        # Longest time axis (dim 1 of each sample) over both splits.
        max_length = max(
            (sample.shape[1] for split in (train_data, test_data) for sample in split),
            default=0)

        # train_data / test_data are object arrays, so every contained
        # ndarray has to be converted and padded individually.
        train_dataset_with_no_paddding = []
        train_dataset = []
        for sample in train_data:
            train_dataset_with_no_paddding.append(sample.transpose(-1, -2).tolist())
            train_dataset.append(self._pad_time_axis(sample, max_length))

        test_dataset = []
        max_length_sample_inTest = []
        for sample in test_data:
            padded = self._pad_time_axis(sample, max_length)
            if sample.shape[1] == max_length:
                # Sample needed no padding: keep it, time axis first.
                max_length_sample_inTest.append(padded.transpose(-1, -2))
            test_dataset.append(padded)

        # Final layout: [number of samples, max time steps, channels]
        train_dataset = torch.stack(train_dataset, dim=0).permute(0, 2, 1)
        test_dataset = torch.stack(test_dataset, dim=0).permute(0, 2, 1)
        train_label = torch.Tensor(train_label)
        test_label = torch.Tensor(test_label)
        channel_len = test_dataset[0].shape[-1]
        input_len = test_dataset[0].shape[-2]

        return (train_len, test_len, input_len, channel_len, output_len,
                train_dataset, train_label, test_dataset, test_label,
                max_length_sample_inTest, train_dataset_with_no_paddding)
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/font/simsun.ttc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/font/simsun.ttc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/AUSLAN/AUSLAN channel_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/AUSLAN/AUSLAN channel_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/AUSLAN/AUSLAN input_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/AUSLAN/AUSLAN input_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG all_sample_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG all_sample_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG channel_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG channel_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG input_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG input_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels Sample_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels Sample_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels channel_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels channel_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels input_gather.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels input_gather.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/heatmap_figure_in_test/JapaneseVowels/JapaneseVowels accuracy=98.11 0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/heatmap_figure_in_test/JapaneseVowels/JapaneseVowels accuracy=98.11 0.jpg
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/__pycache__/encoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/encoder.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/__pycache__/feedForward.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/feedForward.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/__pycache__/multiHeadAttention.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/multiHeadAttention.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/encoder.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | import torch
3 |
4 | from module.feedForward import FeedForward
5 | from module.multiHeadAttention import MultiHeadAttention
6 |
class Encoder(Module):
    """One encoder layer.

    Multi-head self-attention followed by a feed-forward network; each
    sub-layer output goes through dropout, a residual connection and
    LayerNorm.
    """

    def __init__(self,
                 d_model: int,
                 d_hidden: int,
                 q: int,
                 v: int,
                 h: int,
                 device: str,
                 mask: bool = False,
                 dropout: float = 0.1):
        """
        :param d_model: feature size of the layer input and output
        :param d_hidden: hidden size of the feed-forward network
        :param q: per-head query/key size passed to the attention module
        :param v: per-head value size passed to the attention module
        :param h: number of attention heads
        :param device: device string forwarded to the attention module
        :param mask: forwarded to the attention module
        :param dropout: dropout probability
        """
        super().__init__()

        attention_options = dict(d_model=d_model, q=q, v=v, h=h,
                                 mask=mask, device=device, dropout=dropout)
        self.MHA = MultiHeadAttention(**attention_options)
        self.feedforward = FeedForward(d_model=d_model, d_hidden=d_hidden)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.layerNormal_1 = torch.nn.LayerNorm(d_model)
        self.layerNormal_2 = torch.nn.LayerNorm(d_model)

    def forward(self, x, stage):
        """
        :param x: layer input
        :param stage: forwarded unchanged to the attention module
        :return: ``(layer output, score returned by the attention module)``
        """
        # Attention sub-layer.
        attended, score = self.MHA(x, stage)
        after_attention = self.layerNormal_1(self.dropout(attended) + x)

        # Feed-forward sub-layer.
        transformed = self.feedforward(after_attention)
        output = self.layerNormal_2(self.dropout(transformed) + after_attention)

        return output, score
38 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/feedForward.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | import torch
3 | import torch.nn.functional as F
4 |
class FeedForward(Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self,
                 d_model: int,
                 d_hidden: int = 512):
        """
        :param d_model: size of the input and output features
        :param d_hidden: size of the intermediate features
        """
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features=d_model, out_features=d_hidden)
        self.linear_2 = torch.nn.Linear(in_features=d_hidden, out_features=d_model)

    def forward(self, x):
        """Apply both linear layers with a ReLU in between."""
        hidden = F.relu(self.linear_1(x))
        return self.linear_2(hidden)
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/loss.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | import torch
3 | from torch.nn import CrossEntropyLoss
4 |
class Myloss(Module):
    """Cross-entropy loss that casts the targets to ``long`` first."""

    def __init__(self):
        super().__init__()
        self.loss_function = CrossEntropyLoss()

    def forward(self, y_pre, y_true):
        """
        :param y_pre: predictions handed to ``CrossEntropyLoss``
        :param y_true: targets; converted with ``.long()`` before use
        :return: the cross-entropy loss
        """
        targets = y_true.long()
        return self.loss_function(y_pre, targets)
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/multiHeadAttention.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | import torch
3 | import math
4 | import torch.nn.functional as F
5 |
6 |
class MultiHeadAttention(Module):
    """Multi-head scaled dot-product self-attention.

    The heads are computed in one batched matmul: the projected tensors are
    split into ``h`` chunks along the last dim and stacked along dim 0.
    """

    def __init__(self,
                 d_model: int,
                 q: int,
                 v: int,
                 h: int,
                 device: str,
                 mask: bool=False,
                 dropout: float = 0.1):
        """
        :param d_model: feature size of the input and output
        :param q: per-head size of queries and keys
        :param v: per-head size of values
        :param h: number of heads
        :param device: kept for interface compatibility; masking now follows
                       the device of the score tensor instead
        :param mask: if True, a lower-triangular mask is applied when
                     ``stage == 'train'``
        :param dropout: probability of the ``Dropout`` module stored on the
                        instance (``forward`` does not apply it)
        """
        super(MultiHeadAttention, self).__init__()

        self.W_q = torch.nn.Linear(d_model, q * h)
        self.W_k = torch.nn.Linear(d_model, q * h)
        self.W_v = torch.nn.Linear(d_model, v * h)

        self.W_o = torch.nn.Linear(v * h, d_model)

        self.device = device
        self._h = h
        self._q = q

        self.mask = mask
        self.dropout = torch.nn.Dropout(p=dropout)
        self.score = None

    def forward(self, x, stage):
        """
        :param x: input whose last dim is ``d_model``
                  (assumed ``[batch, length, d_model]`` — TODO confirm)
        :param stage: the mask is only applied when this equals ``'train'``
        :return: ``(attention output, raw score)``; the returned score is the
                 scaled ``Q K^T`` product before masking and softmax
        """
        # Split into heads along the feature dim, stack the heads on dim 0.
        Q = torch.cat(self.W_q(x).chunk(self._h, dim=-1), dim=0)
        K = torch.cat(self.W_k(x).chunk(self._h, dim=-1), dim=0)
        V = torch.cat(self.W_v(x).chunk(self._h, dim=-1), dim=0)

        score = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self._q)
        self.score = score

        if self.mask and stage == 'train':
            # Keep the lower triangle (incl. diagonal); everything above it
            # gets a large negative value so softmax sends it to zero.
            # Built from `score`, so it always lives on the right device.
            keep = torch.tril(torch.ones_like(score[0]), diagonal=0) > 0
            score = score.masked_fill(~keep, float(-2 ** 32 + 1))

        score = F.softmax(score, dim=-1)

        attention = torch.matmul(score, V)

        # Undo the head stacking: heads go back onto the feature dim.
        attention_heads = torch.cat(attention.chunk(self._h, dim=0), dim=-1)

        self_attention = self.W_o(attention_heads)

        return self_attention, self.score
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/module/transformer.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | import torch
3 | from torch.nn import ModuleList
4 | from module.encoder import Encoder
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
class Transformer(Module):
    """Two-tower gated transformer.

    A step-wise tower attends over time steps and a channel-wise tower
    attends over channels; a learned softmax gate weights the two flattened
    tower outputs before the final linear layer.
    """

    def __init__(self,
                 d_model: int,
                 d_input: int,
                 d_channel: int,
                 d_output: int,
                 d_hidden: int,
                 q: int,
                 v: int,
                 h: int,
                 N: int,
                 device: str,
                 dropout: float = 0.1,
                 pe: bool = False,
                 mask: bool = False):
        """
        :param d_model: embedding size used by both towers
        :param d_input: number of time steps of each sample
        :param d_channel: number of channels of each sample
        :param d_output: size of the output layer
        :param d_hidden: hidden size of the encoders' feed-forward networks
        :param q: per-head query/key size
        :param v: per-head value size
        :param h: number of attention heads
        :param N: number of encoder layers in each tower
        :param device: device string forwarded to the encoders
        :param dropout: dropout probability forwarded to the encoders
        :param pe: add a sinusoidal positional encoding in the step-wise tower
        :param mask: forwarded to the step-wise encoders only
        """
        super(Transformer, self).__init__()

        self.encoder_list_1 = ModuleList([Encoder(d_model=d_model,
                                                  d_hidden=d_hidden,
                                                  q=q,
                                                  v=v,
                                                  h=h,
                                                  mask=mask,
                                                  dropout=dropout,
                                                  device=device) for _ in range(N)])

        # The channel-wise tower is built without the mask argument.
        self.encoder_list_2 = ModuleList([Encoder(d_model=d_model,
                                                  d_hidden=d_hidden,
                                                  q=q,
                                                  v=v,
                                                  h=h,
                                                  dropout=dropout,
                                                  device=device) for _ in range(N)])

        self.embedding_channel = torch.nn.Linear(d_channel, d_model)
        self.embedding_input = torch.nn.Linear(d_input, d_model)

        self.gate = torch.nn.Linear(d_model * d_input + d_model * d_channel, 2)
        self.output_linear = torch.nn.Linear(d_model * d_input + d_model * d_channel, d_output)

        self.pe = pe
        self._d_input = d_input
        self._d_model = d_model

    def _positional_encoding(self, reference):
        """Return the sinusoidal encoding shaped like ``reference``.

        Every helper tensor is created on ``reference.device`` so nothing is
        copied from the CPU on each forward pass. The cosine half is trimmed
        so an odd ``d_model`` works as well.

        :param reference: one embedded sample, ``[d_input, d_model]``
        """
        device = reference.device
        position = torch.arange(0, self._d_input, dtype=torch.float32, device=device).unsqueeze(-1)
        even_index = torch.arange(0, self._d_model, 2, dtype=torch.float32, device=device)
        frequency = torch.exp(even_index * -(math.log(10000) / self._d_model)).unsqueeze(0)
        angle = torch.matmul(position, frequency)  # shape: [d_input, ceil(d_model / 2)]

        pe = torch.ones_like(reference)
        pe[:, 0::2] = torch.sin(angle)
        pe[:, 1::2] = torch.cos(angle)[:, :self._d_model // 2]
        return pe

    def forward(self, x, stage):
        """
        Forward pass.

        :param x: input (its last dim is fed to the channel embedding and,
                  after a transpose, its second-to-last dim to the step
                  embedding)
        :param stage: tells the encoders whether this is training on the
                      training set or testing; no mask is applied in testing
        :return: output, the gated 2-D encoding, the score of the last
                 step-wise encoder, the score of the last channel-wise
                 encoder, the 3-D step-wise embedding, the 3-D channel-wise
                 embedding, the gate. Both scores are ``None`` when the
                 model was built with ``N == 0``.
        """
        # step-wise tower: scores are over time steps; mask and pe apply here
        encoding_1 = self.embedding_channel(x)
        input_to_gather = encoding_1

        if self.pe:
            encoding_1 = encoding_1 + self._positional_encoding(encoding_1[0])

        score_input = None
        for encoder in self.encoder_list_1:
            encoding_1, score_input = encoder(encoding_1, stage)

        # channel-wise tower: scores are over channels; no mask, no pe
        encoding_2 = self.embedding_input(x.transpose(-1, -2))
        channel_to_gather = encoding_2

        score_channel = None
        for encoder in self.encoder_list_2:
            encoding_2, score_channel = encoder(encoding_2, stage)

        # flatten 3-D tower outputs to 2-D
        encoding_1 = encoding_1.reshape(encoding_1.shape[0], -1)
        encoding_2 = encoding_2.reshape(encoding_2.shape[0], -1)

        # gate: two softmax weights per sample, one for each tower
        gate = F.softmax(self.gate(torch.cat([encoding_1, encoding_2], dim=-1)), dim=-1)
        encoding = torch.cat([encoding_1 * gate[:, 0:1], encoding_2 * gate[:, 1:2]], dim=-1)

        # output layer
        output = self.output_linear(encoding)

        return output, encoding, score_input, score_channel, input_to_gather, channel_to_gather, gate
100 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/mytest/DTW_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from dtw import dtw
4 |
# Alternative pair of sequences:
# x = np.array([2, 0, 1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1)
# y = np.array([1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1)

# Two sequences of different length, stored as column vectors.
x = np.array([8, 9, 1, 9, 6, 1, 3, 5])[:, np.newaxis]
y = np.array([2, 5, 4, 6, 7, 8, 3, 7, 7, 2])[:, np.newaxis]


def euclidean_norm(x, y):
    """Absolute difference, passed to ``dtw`` as its ``dist`` callable."""
    return np.abs(x - y)


d, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=euclidean_norm)

# Print the four values returned by dtw, one labelled entry each.
for label, value in (
        ('d ', d),
        ('cost_matrix \r\n', cost_matrix),
        ('acc_cost_matrix \r\n', acc_cost_matrix),
        ('path ', path),
):
    print(label, value)
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/mytest/HeatMap.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
sns.set()
flights_long = sns.load_dataset("flights")
# DataFrame.pivot takes keyword-only arguments since pandas 2.0; the keyword
# form also works on older versions.
flights = flights_long.pivot(index="month", columns="year", values="passengers")
# Draw an x-y-z heatmap, e.g. a year / month / passengers heatmap.
f, ax = plt.subplots(figsize=(9, 6))
# Draw the heatmap and write each value into its cell.
sns.heatmap(flights, annot=True, fmt="d", ax=ax)
# Set the orientation of the tick labels.
label_y = ax.get_yticklabels()
plt.setp(label_y, rotation=360, horizontalalignment='right')
label_x = ax.get_xticklabels()
plt.setp(label_x, rotation=45, horizontalalignment='right')
plt.show()
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/mytest/HeatMap_DTW.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import numpy as np
4 | from dtw import dtw
5 | import matplotlib.pyplot as plt
6 | import seaborn as sns
7 | import os
8 |
plt.rcParams['font.sans-serif']=['SimHei'] # so that Chinese labels are displayed
plt.rcParams['axes.unicode_minus']=False # these two lines have to be set manually
11 |
12 | from scipy.io import loadmat
13 |
14 | from torch.utils.data import Dataset
15 | import torch
16 |
17 | from scipy.io import loadmat
18 |
19 |
class MyDataset(Dataset):
    """Train/test ``Dataset`` backed by one multivariate time-series ``.mat`` file.

    Both splits are loaded and zero-padded to a common number of time steps
    when the object is built; ``dataset`` only selects which split is served
    by ``__getitem__`` and ``__len__``.
    """

    _SPLITS = ('train', 'test')
    # Field names that must exist inside the MATLAB struct.
    _FIELDS = ('train', 'trainlabels', 'test', 'testlabels')

    def __init__(self, path, dataset):
        """
        :param path: path of the ``.mat`` file to load
        :param dataset: ``'train'`` or ``'test'``, the split this object serves
        """
        super(MyDataset, self).__init__()
        self.dataset = dataset
        (self.train_len,
         self.test_len,
         self.input_len,
         self.channel_len,
         self.output_len,
         self.train_dataset,
         self.train_label,
         self.test_dataset,
         self.test_label) = self.pre_option(path)

    def _require_known_split(self):
        """Raise instead of silently returning ``None`` for an unknown split."""
        if self.dataset not in self._SPLITS:
            raise ValueError(f"dataset must be 'train' or 'test', got {self.dataset!r}")

    def __getitem__(self, index):
        """Return ``(sample, label - 1)`` of the selected split."""
        self._require_known_split()
        if self.dataset == 'train':
            return self.train_dataset[index], self.train_label[index] - 1
        return self.test_dataset[index], self.test_label[index] - 1

    def __len__(self):
        """Return the number of samples in the selected split."""
        self._require_known_split()
        if self.dataset == 'train':
            return self.train_len
        return self.test_len

    @staticmethod
    def _pad_time_axis(sample, length):
        """Return ``sample`` as a float tensor, zero-padded along dim 1 to ``length``."""
        tensor = torch.as_tensor(sample).float()
        missing = length - tensor.shape[1]
        if missing > 0:
            padding = torch.zeros(tensor.shape[0], missing)
            tensor = torch.cat((tensor, padding), dim=1)
        return tensor

    def pre_option(self, path):
        """
        Load the file and zero-pad every sample to the longest time axis
        found in either split.

        :param path: path of the ``.mat`` file
        :return: number of training samples, number of test samples,
                 number of time steps, number of channels, number of classes,
                 training data, training labels, test data, test labels
        :raises ValueError: if the file holds no data variable
        :raises KeyError: if the struct lacks one of the expected fields
        """
        mat = loadmat(path)

        # loadmat adds '__header__', '__version__' and '__globals__'; the data
        # variable is the remaining (last) key.
        data_keys = [key for key in mat if not key.startswith('__')]
        if not data_keys:
            raise ValueError(f'no data variable found in {path!r}')
        data = mat[data_keys[-1]]

        # data[0][0] is the single struct record that holds the four fields.
        record = data[0][0]

        # Look the fields up by name instead of parsing str(dtype), so a
        # missing field fails loudly rather than selecting a wrong column.
        names = data.dtype.names or ()
        missing = [field for field in self._FIELDS if field not in names]
        if missing:
            raise KeyError(f'missing field(s) {missing} in {path!r}')

        train_label = record['trainlabels'].squeeze()
        train_data = record['train'].squeeze()
        test_label = record['testlabels'].squeeze()
        test_data = record['test'].squeeze()

        train_len = train_data.shape[0]
        test_len = test_data.shape[0]
        output_len = len(set(train_label))

        # Longest time axis (dim 1 of each sample) over both splits.
        max_length = max(
            (sample.shape[1] for split in (train_data, test_data) for sample in split),
            default=0)

        # train_data / test_data are object arrays, so every contained
        # ndarray has to be converted and padded individually.
        train_dataset = [self._pad_time_axis(sample, max_length) for sample in train_data]
        test_dataset = [self._pad_time_axis(sample, max_length) for sample in test_data]

        # Final layout: [number of samples, max time steps, channels]
        train_dataset = torch.stack(train_dataset, dim=0).permute(0, 2, 1)
        test_dataset = torch.stack(test_dataset, dim=0).permute(0, 2, 1)
        train_label = torch.Tensor(train_label)
        test_label = torch.Tensor(test_label)
        channel_len = test_dataset[0].shape[-1]
        input_len = test_dataset[0].shape[-2]

        return (train_len, test_len, input_len, channel_len, output_len,
                train_dataset, train_label, test_dataset, test_label)
120 |
def _pairwise_dtw_and_mse(vectors):
    """Return ``(dtw_matrix, mse_matrix)`` for every ordered pair in ``vectors``.

    Entry ``[i, j]`` holds the first value returned by ``dtw`` for
    ``vectors[i]`` and ``vectors[j]`` (absolute difference as local cost),
    respectively their mean squared difference.
    """
    euclidean_norm = lambda x, y: np.abs(x - y)
    size = len(vectors)
    dtw_matrix = np.ones((size, size))
    mse_matrix = np.ones((size, size))
    for i, x in enumerate(vectors):
        for j, y in enumerate(vectors):
            d, _cost_matrix, _acc_cost_matrix, _path = dtw(x, y, dist=euclidean_norm)
            dtw_matrix[i, j] = d
            mse_matrix[i, j] = np.mean((x - y) ** 2)
    return dtw_matrix, mse_matrix


def _save_heatmap(matrix, title, file_name, image_name):
    """Draw ``matrix`` as an annotated heatmap on its own figure and save it.

    A fresh figure is used for every image so consecutive heatmaps are not
    painted on top of each other, and it is closed afterwards so repeated
    calls (one per epoch) do not accumulate open figures.
    """
    out_dir = f'../heatmap_figure/{file_name}'
    os.makedirs(out_dir, exist_ok=True)
    figure = plt.figure()
    sns.heatmap(matrix, annot=True)
    plt.title(title)
    # NOTE(review): the image names contain ':' which Windows does not allow
    # in file names; left unchanged so existing output paths stay the same.
    plt.savefig(f'{out_dir}/{image_name}')
    plt.close(figure)


def heatMap_channel(matrix, file_name, EPOCH):
    """Save DTW and mean-squared-difference heatmaps between the columns of
    the first element of ``matrix``.

    :param matrix: tensor; only ``matrix[0]`` is used
    :param file_name: sub-directory of ``../heatmap_figure`` to write into
    :param EPOCH: value embedded in the image file names
    """
    test_data = matrix[0].detach().numpy()
    # Columns of test_data are compared with each other.
    dtw_matrix, mse_matrix = _pairwise_dtw_and_mse(test_data.T)

    sns.set()
    _save_heatmap(dtw_matrix, 'CHANNEL DTW', file_name, f'channel DTW EPOCH:{EPOCH}.jpg')
    _save_heatmap(mse_matrix, 'CHANNEL difference', file_name, f'channel difference EPOCH:{EPOCH}.jpg')


def heatMap_input(matrix, file_name, EPOCH):
    """Save DTW and mean-squared-difference heatmaps between the rows of the
    first element of ``matrix``.

    :param matrix: tensor; only ``matrix[0]`` is used
    :param file_name: sub-directory of ``../heatmap_figure`` to write into
    :param EPOCH: value embedded in the image file names
    """
    test_data = matrix[0].detach().numpy()
    # Rows of test_data are compared with each other.
    dtw_matrix, mse_matrix = _pairwise_dtw_and_mse(test_data)

    sns.set()
    _save_heatmap(dtw_matrix, 'INPUT DTW', file_name, f'input DTW EPOCH:{EPOCH}.jpg')
    _save_heatmap(mse_matrix, 'INPUT difference', file_name, f'input difference EPOCH:{EPOCH}.jpg')


def heatMap_score(matrix_input, matrix_channel, file_name, EPOCH):
    """Save heatmaps of the first element of both score tensors.

    :param matrix_input: tensor; ``matrix_input[0]`` is drawn as 'SCORE INPUT'
    :param matrix_channel: tensor; ``matrix_channel[0]`` is drawn as 'SCORE CHANNEL'
    :param file_name: sub-directory of ``../heatmap_figure`` to write into
    :param EPOCH: value embedded in the image file names
    """
    score_input = matrix_input[0].detach().numpy()
    score_channel = matrix_channel[0].detach().numpy()

    sns.set()
    # Same file name whether or not the directory already existed (one branch
    # used to write '/ score input ...' with a leading space).
    _save_heatmap(score_input, 'SCORE INPUT', file_name, f'score input EPOCH:{EPOCH}.jpg')
    _save_heatmap(score_channel, 'SCORE CHANNEL', file_name, f'score channel EPOCH:{EPOCH}.jpg')
188 |
189 |
if __name__ == '__main__':
    # Imported before first use; the original imported these half-way through the script.
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt

    # Demo: compare the feature columns of one training sample with DTW and a dot product.
    # length=270 input=29 channel=12 output=9
    path = 'E:\\PyCharmWorkSpace\\mtsdata\\JapaneseVowels\\JapaneseVowels.mat'

    dataset = MyDataset(path, 'train')
    train_dataset = dataset.train_dataset
    print(train_dataset.shape)
    test_data = train_dataset[0].numpy()
    print(test_data[12, :])

    # Point-wise distance handed to dtw(); kept as a module-level name because
    # the functions defined above look it up as a global.
    euclidean_norm = lambda x, y: np.abs(x - y)

    print(test_data.shape)
    n_columns = test_data.shape[1]
    matrix = np.ones((n_columns, n_columns))  # DTW distance
    matrix_1 = np.ones((n_columns, n_columns))  # sum of absolute differences
    matrix_2 = np.ones((n_columns, n_columns))  # mean squared difference
    matrix_3 = np.ones((n_columns, n_columns))  # dot product
    for i in range(n_columns):
        for j in range(n_columns):
            x = test_data[:, i]
            y = test_data[:, j]
            # Only the distance is needed; the remaining outputs are discarded so
            # the dataset `path` above is no longer overwritten by the warp path.
            d, *_ = dtw(x, y, dist=euclidean_norm)
            matrix[i, j] = d
            matrix_1[i, j] = np.sum(np.abs(x - y))
            matrix_2[i, j] = np.mean((x - y) ** 2)
            matrix_3[i, j] = float(x.dot(y))

    sns.set()

    sns.heatmap(matrix, annot=True)
    plt.title('input')
    os.makedirs('../heatmap_figure/lala', exist_ok=True)
    plt.savefig(f'../heatmap_figure/lala/1.jpg')
    plt.show()

    plt.title('channel')
    sns.heatmap(matrix_3, annot=True)
    plt.savefig('2.jpg')
    plt.show()
268 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/mytest/__pycache__/HeatMap_DTW.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/mytest/__pycache__/HeatMap_DTW.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/mytest/kmeans_test.py:
--------------------------------------------------------------------------------
1 | # import cPickle
2 | # X,y = cPickle.load(open('data.pkl','r')) #X和y都是numpy.ndarray类型
3 | # X.shape #输出(1000,2)
4 | # y.shape #输出(1000,)对应每个样本的真实标签
5 | from dataset_process.dataset_process import MyDataset
6 |
# Dataset selection: exactly one `path` assignment is active at a time.
# path = 'E:\\PyCharmWorkSpace\\mtsdata\\ECG\\ECG.mat' # length=100 input=152 channel=2 output=2
# path = 'E:\\PyCharmWorkSpace\\mtsdata\\JapaneseVowels\\JapaneseVowels.mat' # length=270 input=29 channel=12 output=9
path = 'E:\\PyCharmWorkSpace\\mtsdata\\KickvsPunch\\KickvsPunch.mat' # length=10 input=841 channel=62 output=2



# Load the training split and turn it into numpy arrays for clustering.
dataset = MyDataset(path, 'train')
X = dataset.train_dataset
# Alternative reduction that was tried: average over dim 1 instead of flattening.
# X = torch.mean(X, dim=1).numpy()
# Flatten every sample into a single row: (n_samples, everything else).
X = X.reshape(X.shape[0], -1).numpy()
y = dataset.train_label.numpy()
18 |
19 |
20 | import numpy as np
21 | import matplotlib.pyplot as plt
22 | from utils.kmeans import KMeans
def draw(X, Y, n_clusters=2):
    """Cluster ``X`` with KMeans and plot the first two feature columns.

    Every sample is drawn as its integer label from ``Y`` in the colour of the
    cluster it was assigned to; each centroid is marked with an 'x'. The plot
    title shows the SSE reported by the fitted model.

    :param X: 2-D array of samples; only columns 0 and 1 are plotted
    :param Y: per-sample labels, indexable with the same indices as ``X``
    :param n_clusters: number of clusters to fit and draw (defaults to the
        previously hard-coded value 2)
    :return: None
    """
    clf = KMeans(n_clusters=n_clusters, initCent='random', max_iter=100)
    clf.fit(X)
    cents = clf.centroids  # cluster centres
    labels = clf.labels  # index of the cluster each sample was assigned to
    sse = clf.sse

    # One colour per cluster; wrap around if there are more clusters than colours.
    colors = ['b', 'g', 'r', 'k', 'c', 'm', 'y', '#e24fff', '#524C90', '#845868']
    for i in range(n_clusters):
        color = colors[i % len(colors)]
        index = np.nonzero(labels == i)[0]
        for x0, x1, y_i in zip(X[index, 0], X[index, 1], Y[index]):
            plt.text(x0, x1, str(int(y_i)), color=color,
                     fontdict={'weight': 'bold', 'size': 9})
        plt.scatter(cents[i, 0], cents[i, 1], marker='x', color=color, linewidths=12)

    plt.title("SSE={:.2f}".format(sse))
    plt.axis([-30, 30, -30, 30])
    plt.show()
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/result_figure/CharacterTrajectories 96.6% Adam epoch=100 batch=4 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/CharacterTrajectories 96.6% Adam epoch=100 batch=4 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=10 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=10 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=20 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=20 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/result_figure/ECG 87.0% Adam epoch=50 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/ECG 87.0% Adam epoch=50 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/result_figure/JapaneseVowels 98.11% Adagrad epoch=100 batch=3 lr=0.0001 pe=True mask=True [512,8,8,8,8,0.2].png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/JapaneseVowels 98.11% Adagrad epoch=100 batch=3 lr=0.0001 pe=True mask=True [512,8,8,8,8,0.2].png
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/run.py:
--------------------------------------------------------------------------------
1 | # @Time : 2021/01/22 25:16
2 | # @Author : SY.M
3 | # @FileName: run.py
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from dataset_process.dataset_process import MyDataset
8 | import torch.optim as optim
9 | from time import time
10 | from tqdm import tqdm
11 | import os
12 |
13 | from module.transformer import Transformer
14 | from module.loss import Myloss
15 | from utils.random_seed import setup_seed
16 | from utils.visualization import result_visualization
17 |
18 | # from mytest.gather.main import draw
19 |
setup_seed(30)  # fix the random seed
reslut_figure_path = 'result_figure'  # directory the result figure is saved to

# Dataset selection: uncomment exactly one `path`.
# Raw strings are used so that backslashes such as "\P" and "\d" are not
# parsed as (invalid) escape sequences.
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\AUSLAN\AUSLAN.mat'  # length=1140 input=136 channel=22 output=95
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\CharacterTrajectories\CharacterTrajectories.mat'
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\CMUsubject16\CMUsubject16.mat'  # length=29,29 input=580 channel=62 output=2
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\ECG\ECG.mat'  # length=100 input=152 channel=2 output=2
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\JapaneseVowels\JapaneseVowels.mat'  # length=270 input=29 channel=12 output=9
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\Libras\Libras.mat'  # length=180 input=45 channel=2 output=15
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\UWave\UWave.mat'  # length=4278 input=315 channel=3 output=8
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\KickvsPunch\KickvsPunch.mat'  # length=10 input=841 channel=62 output=2
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\NetFlow\NetFlow.mat'  # length=803 input=997 channel=4 output=only 1 and 13
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\ArabicDigits\ArabicDigits.mat'  # length=6600 input=93 channel=13 output=10
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\PEMS\PEMS.mat'
# path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\Wafer\Wafer.mat'
path = r'E:\PyCharmWorkSpace\dataset\MTS_dataset\WalkvsRun\WalkvsRun.mat'

test_interval = 5  # evaluate every `test_interval` epochs
# Forwarded to result_visualization; per the original note a figure is only
# saved when a value is >= draw_key (see result_visualization for which one).
draw_key = 1
# Dataset name = last path component without its extension.
_base_name = path.split('\\')[-1]
file_name = _base_name[0:_base_name.index('.')]

# Hyper-parameters
EPOCH = 100
BATCH_SIZE = 3
LR = 1e-4
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # GPU if available, else CPU
print(f'use device: {DEVICE}')

d_model = 512
d_hidden = 1024
q = 8
v = 8
h = 8
N = 8
dropout = 0.2
# Per the original notes, `pe` and `mask` configure the score=input tower of
# the two-tower model; the score=channel tower uses neither by default.
pe = True
mask = True
# Optimizer selection: 'Adagrad' or 'Adam'
optimizer_name = 'Adagrad'

train_dataset = MyDataset(path, 'train')
test_dataset = MyDataset(path, 'test')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

DATA_LEN = train_dataset.train_len  # number of training samples
d_input = train_dataset.input_len  # number of time steps
d_channel = train_dataset.channel_len  # number of channels of the time series
d_output = train_dataset.output_len  # number of classes

# Show the dimensions
print('data structure: [lines, timesteps, features]')
print(f'train data size: [{DATA_LEN, d_input, d_channel}]')
print(f'mytest data size: [{train_dataset.test_len, d_input, d_channel}]')
print(f'Number of classes: {d_output}')

# Build the Transformer model
net = Transformer(d_model=d_model, d_input=d_input, d_channel=d_channel, d_output=d_output, d_hidden=d_hidden,
                  q=q, v=v, h=h, N=N, dropout=dropout, pe=pe, mask=mask, device=DEVICE).to(DEVICE)
# Loss function (cross-entropy according to the original note)
loss_function = Myloss()
if optimizer_name == 'Adagrad':
    optimizer = optim.Adagrad(net.parameters(), lr=LR)
elif optimizer_name == 'Adam':
    optimizer = optim.Adam(net.parameters(), lr=LR)
else:
    # Fail here instead of with a NameError on the first optimizer use.
    raise ValueError(f'Unsupported optimizer_name: {optimizer_name!r}')

# Accuracy history
correct_on_train = []
correct_on_test = []
# Loss history
loss_list = []
time_cost = 0
93 |
94 |
95 | # 测试函数
def test(dataloader, flag='test_set'):
    """Evaluate the global ``net`` on ``dataloader`` and record the accuracy.

    The accuracy (in percent, rounded to two decimals) is appended to
    ``correct_on_test`` when ``flag == 'test_set'`` and to ``correct_on_train``
    when ``flag == 'train_set'``; any other flag records nothing.

    Note that the network is switched to eval mode and left in that mode.

    :param dataloader: iterable yielding ``(x, y)`` batches
    :param flag: 'test_set' or 'train_set'; selects the history list and is printed
    :return: accuracy in percent, rounded to two decimals
    """
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        net.eval()
        for inputs, targets in dataloader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            # Only the first output (class scores) is needed here.
            scores = net(inputs, 'test')[0]
            predicted = torch.max(scores.data, dim=-1)[1]
            n_seen += predicted.shape[0]
            n_correct += (predicted == targets.long()).sum().item()

        history = {'test_set': correct_on_test, 'train_set': correct_on_train}.get(flag)
        if history is not None:
            history.append(round((100 * n_correct / n_seen), 2))
        print(f'Accuracy on {flag}: %.2f %%' % (100 * n_correct / n_seen))

        return round((100 * n_correct / n_seen), 2)
114 |
115 |
116 | # 训练函数
def train():
    """Train the global ``net`` for ``EPOCH`` epochs and plot the results.

    Every ``test_interval`` epochs the model is evaluated on the test and the
    training set; whenever the test accuracy improves, the whole model is
    saved to ``saved_model/<file_name> batch=<BATCH_SIZE>.pkl``. After training
    that file is renamed so that it also carries the best accuracy, and
    ``result_visualization`` is called with the collected histories.

    :return: None
    """
    max_accuracy = 0
    model_dir = 'saved_model'
    checkpoint = f'{model_dir}/{file_name} batch={BATCH_SIZE}.pkl'
    # torch.save does not create missing directories.
    os.makedirs(model_dir, exist_ok=True)

    pbar = tqdm(total=EPOCH)
    begin = time()
    for index in range(EPOCH):
        # test() leaves the network in eval mode, so training mode has to be
        # re-enabled at the start of every epoch, not just once.
        net.train()
        for x, y in train_dataloader:
            optimizer.zero_grad()

            y_pre, _, _, _, _, _, _ = net(x.to(DEVICE), 'train')

            loss = loss_function(y_pre, y.to(DEVICE))

            print(f'Epoch:{index + 1}:\t\tloss:{loss.item()}')
            loss_list.append(loss.item())

            loss.backward()

            optimizer.step()

        if ((index + 1) % test_interval) == 0:
            current_accuracy = test(test_dataloader)
            test(train_dataloader, 'train_set')
            print(f'当前最大准确率\t测试集:{max(correct_on_test)}%\t 训练集:{max(correct_on_train)}%')

            if current_accuracy > max_accuracy:
                max_accuracy = current_accuracy
                torch.save(net, checkpoint)

        pbar.update()
    pbar.close()

    # No checkpoint exists if no evaluation ran or the accuracy never rose above 0.
    if os.path.exists(checkpoint):
        # os.replace also succeeds when the target already exists (os.rename
        # raises on Windows in that case).
        os.replace(checkpoint,
                   f'{model_dir}/{file_name} {max_accuracy} batch={BATCH_SIZE}.pkl')

    end = time()
    # Local value in minutes; the module-level `time_cost` is not modified.
    time_cost = round((end - begin) / 60, 2)

    # Result figure
    result_visualization(loss_list=loss_list, correct_on_test=correct_on_test, correct_on_train=correct_on_train,
                         test_interval=test_interval,
                         d_model=d_model, q=q, v=v, h=h, N=N, dropout=dropout, DATA_LEN=DATA_LEN, BATCH_SIZE=BATCH_SIZE,
                         time_cost=time_cost, EPOCH=EPOCH, draw_key=draw_key, reslut_figure_path=reslut_figure_path,
                         file_name=file_name,
                         optimizer_name=optimizer_name, LR=LR, pe=pe, mask=mask)


if __name__ == '__main__':
    train()
165 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/run_with_saved_model.py:
--------------------------------------------------------------------------------
1 | # @Time : 2021/01/22 25:50
2 | # @Author : SY.M
3 | # @FileName: run_with_saved_model.py
4 |
5 |
6 | import torch
7 | print('当前使用的pytorch版本:', torch.__version__)
8 | from utils.random_seed import setup_seed
9 | from torch.utils.data import DataLoader
10 | from dataset_process.dataset_process import MyDataset
11 | from utils.heatMap import heatMap_all
12 | from utils.TSNE import gather_by_tsne
13 | from utils.TSNE import gather_all_by_tsne
14 | import numpy as np
15 | from utils.colorful_line import draw_colorful_line
16 |
setup_seed(30)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Dataset dimensions, for reference
# ArabicDigits length=6600 input=93 channel=13 output=10
# AUSLAN length=1140 input=136 channel=22 output=95
# CharacterTrajectories
# CMUsubject16 length=29,29 input=580 channel=62 output=2
# ECG length=100 input=152 channel=2 output=2
# JapaneseVowels length=270 input=29 channel=12 output=9
# Libras length=180 input=45 channel=2 output=15
# UWave length=4278 input=315 channel=3 output=8
# KickvsPunch length=10 input=841 channel=62 output=2
# NetFlow length=803 input=997 channel=4 output=only 1 and 13; needs changes in the dataset processing code
# Wafer length=803 input=997 channel=4

# Model to run. The file name is parsed as "<dataset> <accuracy> batch=<batch size>.pkl".
save_model_path = 'saved_model/ECG 91.0 batch=2.pkl'
_model_file = save_model_path.split('/')[-1]
file_name = _model_file.split(' ')[0]
# Raw f-string: backslashes stay literal instead of forming invalid escapes such as "\P".
path = rf'E:\PyCharmWorkSpace\dataset\MTS_dataset\{file_name}\{file_name}.mat'  # dataset path

# Values used to name the generated figures
ACCURACY = _model_file.split(' ')[1]  # accuracy of the loaded model
BATCH_SIZE = int(save_model_path[save_model_path.find('=')+1:save_model_path.rfind('.')])  # batch size of the loaded model
heatMap_or_not = False  # draw heat maps of the score matrices
gather_or_not = False  # draw step-wise / channel-wise clustering of a single sample
gather_all_or_not = True  # draw clustering of all samples after feature extraction

# Load the model onto the same device the inputs are moved to; a checkpoint
# written on a GPU machine can therefore still be loaded on a CPU-only one.
net = torch.load(save_model_path, map_location=DEVICE)
# Inference only: make sure layers such as dropout are in evaluation mode.
net.eval()
# Load the test data
test_dataset = MyDataset(path, 'test')
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f'step为最大值的sample个数:{len(test_dataset.max_length_sample_inTest)}')
if len(test_dataset.max_length_sample_inTest) == 0:
    gather_or_not = False
    heatMap_or_not = False
    print('测试集中没有step为最大值的样本, 将不能绘制make sense 的heatmap 和 gather 图 可尝试换一个数据集')

correct = 0
total = 0
all_sample_X = []
all_sample_Y = []
with torch.no_grad():
    for x, y in test_dataloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_pre, encoding, score_input, score_channel, gather_input, gather_channel, gate = net(x, 'test')

        # Everything converted to numpy later is moved to the CPU first.
        all_sample_X.append(encoding.cpu())
        all_sample_Y.append(y.cpu())

        if heatMap_or_not or gather_or_not:
            # Convert the batch once instead of once per candidate sample.
            batch_rows = x.cpu().numpy().tolist()

        if heatMap_or_not:
            for index, sample in enumerate(test_dataset.max_length_sample_inTest):
                sample_rows = sample.numpy().tolist()
                if sample_rows in batch_rows:
                    target_index = batch_rows.index(sample_rows)
                    print('正在绘制heatmap图...')
                    heatMap_all(score_input[target_index].cpu(), score_channel[target_index].cpu(), sample,
                                'heatmap_figure_in_test', file_name, ACCURACY, index)
                    print('heatmap图绘制完成!')

        if gather_or_not:
            for index, sample in enumerate(test_dataset.max_length_sample_inTest):
                sample_rows = sample.numpy().tolist()
                if sample_rows in batch_rows:
                    target_index = batch_rows.index(sample_rows)
                    print('正在绘制gather图...')
                    gather_by_tsne(gather_input[target_index].cpu().numpy(),
                                   np.arange(gather_input[target_index].shape[0]), index, file_name+' input_gather')
                    gather_by_tsne(gather_channel[target_index].cpu().numpy(),
                                   np.arange(gather_channel[target_index].shape[0]), index, file_name+' channel_gather')
                    print('gather图绘制完成!')
                    # NOTE(review): the line plot and the flag reset are kept inside
                    # the match branch so `target_index` is always defined - confirm
                    # this is the intended nesting.
                    draw_data = x[target_index].transpose(-1, -2)[0].cpu().numpy()
                    draw_colorful_line(draw_data)
                    gather_or_not = False

        _, label_index = torch.max(y_pre.data, dim=-1)
        total += label_index.shape[0]
        correct += (label_index == y.long()).sum().item()

if gather_all_or_not:
    all_sample_X = torch.cat(all_sample_X, dim=0).numpy()
    all_sample_Y = torch.cat(all_sample_Y, dim=0).numpy()
    print('正在绘制gather图...')
    gather_all_by_tsne(all_sample_X, all_sample_Y, test_dataset.output_len, file_name+' all_sample_gather')
    print('gather图绘制完成!')

print(f'Accuracy: %.2f %%' % (100 * correct / total))
98 |
99 |
100 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/TSNE.py:
--------------------------------------------------------------------------------
1 | from dataset_process.dataset_process import MyDataset
2 | import torch
3 | import numpy as np
4 | import os
5 | import matplotlib.pyplot as plt
6 |
7 | # path = 'E:\\PyCharmWorkSpace\\mtsdata\\ECG\\ECG.mat' # lenth=100 input=152 channel=2 output=2
8 | #
9 | # dataset = MyDataset(path, 'train')
10 | # X = dataset.train_dataset
11 | # X = torch.mean(X, dim=1).numpy()
12 | # # X = X.reshape(X.shape[0], -1).numpy()
13 | # y = dataset.train_label.numpy()
14 |
15 | from matplotlib import cm
16 |
17 |
def plot_with_labels(lowDWeights, labels, kinds, file_name):
    """Draw 2-D points as their labels, coloured per label, and save the figure.

    The image is written to ``gather_figure/<first word of file_name>/<file_name>.jpg``.

    :param lowDWeights: 2-D array; column 0 is used as x and column 1 as y
    :param labels: one label per row of ``lowDWeights``; used as text and to pick the colour
    :param kinds: number of classes; spreads the labels over the 0-255 colour range
    :param file_name: output file name (without extension); its first
        space-separated word is the sub-directory name
    :return: None
    """
    # The size must be configured before the figure is created, otherwise it
    # only affects figures created later.
    plt.rcParams['figure.figsize'] = (10.0, 10.0)
    fig = plt.figure()

    # The data is already 2-D: split it into x and y coordinates.
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    for x, y, s in zip(X, Y, labels):
        # Map the label onto the 0-255 colour range so classes are distinguishable.
        c = cm.rainbow(int(255/kinds * s))
        plt.text(x, y, s, backgroundcolor=c, fontsize=6)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())

    plt.title('Clustering Step-Wise after Embedding')

    save_dir = f'gather_figure/{file_name.split(" ")[0]}'
    os.makedirs(save_dir, exist_ok=True)

    plt.savefig(f'{save_dir}/{file_name}.jpg', dpi=600)
    plt.close(fig)
38 |
def plot_only(lowDWeights, labels, index, file_name):
    """
    Draw 2-D points as their labels, coloured by hand-tuned regions, and save the figure.

    The image is written to
    ``gather_figure/<first word of file_name>/<file_name> <index>.jpg``.

    :param lowDWeights: 2-D array; column 0 is used as x and column 1 as y
    :param labels: one label per row of ``lowDWeights``, drawn as text
    :param index: appended to the file name so several figures do not overwrite each other
    :param file_name: dataset name and clustering kind, separated by a space
    :return: None
    """
    # The size must be configured before the figure is created, otherwise it
    # only affects figures created later.
    plt.rcParams['figure.figsize'] = (10.0, 10.0)
    fig = plt.figure()

    # The data is already 2-D: split it into x and y coordinates.
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    # Custom colouring of the clustering figure is done in this loop; the
    # thresholds are hard-coded regions of the embedding plane.
    for x, y, s in zip(X, Y, labels):
        if x < -850:
            position = 255
        elif 0.5*x - 225 < y:
            position = 0
        elif x < 1500:
            position = 50
        else:
            position = 100

        c = cm.rainbow(position)
        plt.text(x, y, s, backgroundcolor=c, fontsize=6)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())

    save_dir = f'gather_figure/{file_name.split(" ")[0]}'
    os.makedirs(save_dir, exist_ok=True)

    plt.savefig(f'{save_dir}/{file_name} {index}.jpg', dpi=600)
    plt.close(fig)
79 |
80 |
81 | from sklearn.manifold import TSNE
82 |
83 |
def gather_by_tsne(X: np.ndarray,
                   Y: np.ndarray,
                   index: int,
                   file_name: str):
    """
    Reduce ``X`` to two dimensions with t-SNE and plot it with ``plot_only``.

    :param X: 2-D data, one row per point
    :param Y: label of every row of ``X``
    :param index: forwarded to ``plot_only`` to distinguish output files
    :param file_name: forwarded to ``plot_only`` to name the output file
    :return: None
    """
    # NOTE(review): recent scikit-learn releases require perplexity < n_samples;
    # confirm callers always pass more than 30 rows.
    reducer = TSNE(perplexity=30, n_components=2, init='pca', n_iter=4000)
    embedded = reducer.fit_transform(X[:, :])
    plot_only(embedded, Y[:], index, file_name)
92 |
93 |
def gather_all_by_tsne(X: np.ndarray,
                       Y: np.ndarray,
                       kinds: int,
                       file_name: str):
    """
    Reduce ``X`` to two dimensions with t-SNE and plot it with ``plot_with_labels``.

    :param X: 2-D data to embed, one row per sample
    :param Y: label of every row of ``X``
    :param kinds: number of classes, forwarded to ``plot_with_labels``
    :param file_name: used to name the output file
    :return: None
    """
    # NOTE(review): recent scikit-learn releases require perplexity < n_samples;
    # confirm callers always pass more than 30 rows.
    reducer = TSNE(perplexity=30, n_components=2, init='pca', n_iter=4000)
    embedded = reducer.fit_transform(X[:, :])
    plot_with_labels(embedded, Y[:], kinds, file_name)
110 |
111 |
if __name__ == '__main__':
    # Demo: embed the per-sample mean over dim 1 of the ECG training set.
    path = 'E:\\PyCharmWorkSpace\\mtsdata\\ECG\\ECG.mat'  # length=100 input=152 channel=2 output=2

    dataset = MyDataset(path, 'train')
    X = dataset.train_dataset
    X = torch.mean(X, dim=1).numpy()
    # X = X.reshape(X.shape[0], -1).numpy()
    Y = dataset.train_label.numpy()
    # gather_by_tsne needs an index and a file name as well; calling it with
    # only (X, Y) raised a TypeError.
    gather_by_tsne(X, Y, 0, 'ECG train_gather')
121 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/TSNE.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/TSNE.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/colorful_line.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/colorful_line.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/draw_line.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/draw_line.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/heatMap.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/heatMap.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/random_seed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/random_seed.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/visualization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/visualization.cpython-37.pyc
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/colorful_line.py:
--------------------------------------------------------------------------------
1 | from matplotlib.collections import LineCollection
2 | import numpy as np
3 | import math
4 | import matplotlib.pyplot as plt
5 |
6 |
def draw_colorful_line(draw_data: np.ndarray):
    """Show ``draw_data`` as a line whose segments are coloured by point index.

    Indices 1-3 are red, 4-6 blue, 7-9 cyan and every other index purple.

    :param draw_data: 1-D sequence of y values; x is the position in the sequence
    :return: None
    """
    n_points = len(draw_data)
    xs = np.arange(n_points)
    ys = draw_data

    # Colour of the explicitly assigned indices; everything else is purple.
    index_colour = {}
    for colour, members in (('red', [1, 2, 3]), ('blue', [4, 5, 6]), ('cyan', [7, 8, 9])):
        for member in members:
            index_colour[member] = colour
    colors = [index_colour.get(i, 'purple') for i in range(n_points)]

    points = np.array([xs, ys]).T.reshape(-1, 1, 2)
    # Colours belong to segments, not points, so pair every point with its successor.
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    collection = LineCollection(segments, color=colors)

    ax = plt.axes()
    ax.set_xlim(0, n_points)
    ax.set_ylim(min(ys), max(ys))
    ax.add_collection(collection)
    plt.show()
    plt.close()
38 |
39 |
if __name__ == '__main__':
    # Demo: a cosine curve, first half red and second half black, with the
    # line width growing along x.
    pi = 3.1415

    x = np.linspace(0, 4 * pi, 100)
    y = [math.cos(xx) for xx in x]
    lwidths = abs(x)
    color = ['#FF0000' if i < 50 else '#000000' for i in range(len(y))]
    print(color)
    separator = '--------------------------------------'
    print(separator)
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    print(np.array([x, y]).shape)
    print(points.shape)
    print(separator)
    head, tail = points[:-1], points[1:]
    print(head.shape)
    print(tail.shape)
    # Each segment joins a point with its successor.
    segments = np.concatenate([head, tail], axis=1)
    print(segments.shape)
    lc = LineCollection(segments, linewidths=lwidths, color=color)

    ax = plt.axes()
    ax.set_xlim(min(x), max(x))
    ax.set_ylim(min(y), max(y))
    ax.add_collection(lc)
    plt.show()
    plt.close()

    # Alternative drawing code that was tried:
    # fig, a = plt.subplots()
    # a.add_collection(lc)
    # a.set_xlim(0, 4*pi)
    # a.set_ylim(-1.1, 1.1)
    # fig.show()
80 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/draw_line.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
def draw_line(list_1, list_2, list_3):
    """Plot three weight histories in two stacked subplots and show the figure.

    The upper plot shows ``list_3``; the lower plot shows ``list_1`` and ``list_2``.

    :param list_1: values plotted in red as 'Weight_for_step' (lower plot)
    :param list_2: values plotted in blue as 'Weight_for_channel' (lower plot)
    :param list_3: values plotted in red as 'Target_Weight_for_step' (upper plot)
    :return: None
    """
    try:
        plt.style.use('seaborn')
    except OSError:
        # Newer matplotlib releases only ship this style under the 'seaborn-v0_8' name.
        plt.style.use('seaborn-v0_8')
    fig = plt.figure()
    ax1 = fig.add_subplot(211)
    ax2 = fig.add_subplot(212)

    ax1.plot(list_3, color='red', label='Target_Weight_for_step')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('value')
    ax1.set_title('One Sample Weight for Step')

    ax2.plot(list_1, color='red', label='Weight_for_step')
    ax2.plot(list_2, color='blue', label='Weight_for_channel')
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('value')
    ax2.set_title('Mean Weight Allocation')

    # A bare plt.legend() only reaches the last axes; give both axes a legend.
    ax1.legend(loc='best')
    ax2.legend(loc='best')
    plt.show()
23 |
def draw_heatmap_anylasis(sample):
    """Compare selected pairs of feature series of one sample in five stacked plots.

    The plotted pairs are (11, 7), (11, 6), (11, 3), (3, 4) and (9, 0); the
    first series of every pair is drawn in red.

    :param sample: 2-D tensor whose last dimension indexes the features; it is
        transposed so that row ``k`` is the series of feature ``k``
    :return: None
    """
    sample = sample.transpose(-1, -2)

    def series(channel):
        # Series of one feature as a plain list.
        return sample[channel, :].numpy().tolist()

    pairs = [(11, 7), (11, 6), (11, 3), (3, 4), (9, 0)]
    for row, (reference, other) in enumerate(pairs, start=1):
        ax = plt.subplot(5, 1, row)
        # Labels are derived from the plotted channel so they cannot drift apart
        # (channel 9 used to be labelled 'channel_6').
        ax.plot(series(reference), color='red', label=f'channel_{reference}')
        ax.plot(series(other), label=f'channel_{other}')
        ax.legend(loc='best')

    plt.suptitle('Feature-wise Series Compare')
    plt.show()
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/heatMap.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dtw import dtw
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 | import os
6 | import torch
7 |
def heatMap_all(score_input: torch.Tensor,  # 2-D data
                score_channel: torch.Tensor,  # 2-D data
                x: torch.Tensor,  # 2-D data
                save_root: str,
                file_name: str,
                accuracy: str,
                index: int) -> None:
    """Save a 2x2 figure comparing attention scores with distances in the raw sample.

    Layout: top-left ``score_channel``, top-right pairwise DTW distance between
    the columns of ``x``, bottom-left ``score_input``, bottom-right pairwise L2
    distance between the rows of ``x``. The image is written to
    ``<save_root>/<file_name>/<file_name> accuracy=<accuracy> <index>.jpg``.

    :param score_input: score matrix shown in the bottom-left panel
    :param score_channel: score matrix shown in the top-left panel
    :param x: the raw 2-D sample the distances are computed from
    :param save_root: root directory of the output
    :param file_name: dataset name; used as sub-directory and file name prefix
    :param accuracy: accuracy text embedded in the file name
    :param index: suffix that keeps several figures from overwriting each other
    :return: None
    """
    # .cpu() is a no-op for CPU tensors and makes CUDA tensors convertible to numpy.
    score_channel = score_channel.detach().cpu().numpy()
    score_input = score_input.detach().cpu().numpy()
    draw_data = x.detach().cpu().numpy()

    # Point-wise distance handed to dtw().
    euclidean_norm = lambda a, b: np.abs(a - b)

    n_rows, n_cols = draw_data.shape[0], draw_data.shape[1]
    matrix_00 = np.ones((n_cols, n_cols))  # DTW distance between columns of x
    matrix_11 = np.ones((n_rows, n_rows))  # L2 distance between rows of x

    # L2 distance of every row to row i, one row of the result at a time.
    for i in range(n_rows):
        matrix_11[i, :] = np.sqrt(np.sum((draw_data - draw_data[i, :]) ** 2, axis=1))

    # DTW with a symmetric point distance is itself symmetric, so only the upper
    # triangle is computed and mirrored.
    columns = draw_data.T
    for i in range(n_cols):
        for j in range(i, n_cols):
            d, *_ = dtw(columns[i, :], columns[j, :], dist=euclidean_norm)
            matrix_00[i, j] = d
            matrix_00[j, i] = d

    plt.rcParams['figure.figsize'] = (10.0, 8.0)  # figure size
    fig = plt.figure()

    panels = (
        (221, score_channel),  # channel-wise attention
        (222, matrix_00),      # channel-wise DTW
        (223, score_input),    # step-wise attention
        (224, matrix_11),      # step-wise L2 distance
    )
    for position, values in panels:
        plt.subplot(position)
        sns.heatmap(values, cmap="YlGnBu", vmin=0)

    save_dir = f'{save_root}/{file_name}'
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(f'{save_dir}/{file_name} accuracy={accuracy} {index}.jpg', dpi=400)

    plt.close(fig)
69 |
70 |
if __name__ == '__main__':
    # Manual smoke demo: draw the first slice of a dummy 2x3x4 tensor in
    # four titled heat-map panels and show the window.
    matrix = torch.Tensor(range(24)).reshape(2, 3, 4)
    print(matrix.shape)
    file_name = 'lall'
    epcoh = 1

    data_channel = matrix.detach()
    data_input = matrix.detach()

    # Panel 1 uses the "channel" tensor, panels 2-4 the "input" tensor.
    panel_sources = (data_channel, data_input, data_input, data_input)
    for position, source in enumerate(panel_sources, start=1):
        plt.subplot(2, 2, position)
        sns.heatmap(source[0].data.cpu().numpy())
        plt.title(str(position))

    plt.suptitle("JapaneseVowels Attention Heat Map", fontsize='x-large', fontweight='bold')
    # plt.savefig('result_figure/JapaneseVowels Attention Heat Map.png')
    plt.show()
99 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/kmeans.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | '''
4 | @author: wepon, http://2hwp.com
5 | Reference:
6 | Book: <>
7 | Software: sklearn.cluster.KMeans
8 | '''
9 | import numpy as np
10 |
11 |
class KMeans(object):
    """Basic k-means clustering (iterative assign / re-centre).

    Parameters
    ----------
    n_clusters:
        Number of clusters (k). Overridden by the number of rows of
        ``initCent`` when an array is supplied.
    initCent:
        Centroid initialisation: the string ``'random'`` (default) or an
        array-like object exposing ``__array__`` with one centroid per row.
    max_iter:
        Maximum number of assign/update iterations.

    Attributes set by ``fit``
    -------------------------
    centroids:
        ``(n_clusters, n_features)`` float array.
    clusterAssment:
        ``(m, 2)`` array; column 0 is the cluster index of each sample,
        column 1 the squared distance to that cluster's centroid.
    labels:
        Column 0 of ``clusterAssment``.
    sse:
        Sum of column 1 of ``clusterAssment``.
    """

    def __init__(self, n_clusters=5, initCent='random', max_iter=300):
        if hasattr(initCent, '__array__'):
            # np.array(...) copies, so fit() never mutates the caller's
            # array; the builtin float replaces the removed np.float alias.
            self.centroids = np.array(initCent, dtype=float)
            n_clusters = self.centroids.shape[0]
        else:
            self.centroids = None

        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.initCent = initCent
        self.clusterAssment = None
        self.labels = None
        self.sse = None

    @staticmethod
    def _as_array(X):
        """Return X as an ndarray, raising TypeError if it cannot be converted."""
        if isinstance(X, np.ndarray):
            return X
        try:
            return np.asarray(X)
        except Exception as err:
            raise TypeError("numpy.ndarray required for X") from err

    # Euclidean distance between two points.
    def _distEclud(self, vecA, vecB):
        return np.linalg.norm(vecA - vecB)

    # Draw k random centroids inside the per-feature bounds of the data.
    def _randCent(self, X, k):
        n = X.shape[1]  # number of features
        lows = X.min(axis=0)
        spans = (X.max(axis=0) - lows).astype(float)
        # rand(n, k) draws k values per feature in feature order, i.e. the
        # same generator sequence as filling one column at a time.
        return lows + spans * np.random.rand(n, k).T

    def _nearest(self, point):
        """Return (index, distance) of the centroid closest to ``point``.

        Yields ``(-1, inf)`` when no distance compares smaller than
        infinity (e.g. every distance is NaN).
        """
        minDist = np.inf
        minIndex = -1
        for j in range(self.n_clusters):
            distJI = self._distEclud(self.centroids[j, :], point)
            if distJI < minDist:
                minDist = distJI
                minIndex = j
        return minIndex, minDist

    def fit(self, X):
        """Cluster the rows of ``X`` and populate the result attributes."""
        X = self._as_array(X)

        m = X.shape[0]  # number of samples
        # Start every label at -1 (not uninitialised memory) so the first
        # pass always registers as a change.
        self.clusterAssment = np.full((m, 2), -1.0)
        # isinstance guard: comparing an ndarray with a string would be an
        # element-wise comparison rather than a plain boolean.
        if isinstance(self.initCent, str) and self.initCent == 'random':
            self.centroids = self._randCent(X, self.n_clusters)

        for _ in range(self.max_iter):
            clusterChanged = False
            for i in range(m):  # assign each sample to its nearest centroid
                minIndex, minDist = self._nearest(X[i, :])
                if self.clusterAssment[i, 0] != minIndex:
                    clusterChanged = True
                self.clusterAssment[i, :] = minIndex, minDist ** 2

            if not clusterChanged:  # no label moved: converged
                break
            for i in range(self.n_clusters):  # move centroids to cluster means
                ptsInClust = X[self.clusterAssment[:, 0] == i]
                # An empty cluster keeps its previous centroid; the mean of
                # zero points would be NaN and kill the cluster for good.
                if ptsInClust.shape[0] > 0:
                    self.centroids[i, :] = ptsInClust.mean(axis=0)

        self.labels = self.clusterAssment[:, 0]
        self.sse = self.clusterAssment[:, 1].sum()

    def predict(self, X):
        """Return, for every row of ``X``, the index of the nearest centroid."""
        X = self._as_array(X)

        m = X.shape[0]  # number of samples
        preds = np.empty((m,))
        for i in range(m):
            preds[i] = self._nearest(X[i, :])[0]
        return preds
108 |
109 |
class biKMeans(object):
    """Bisecting k-means: repeatedly split one cluster in two with KMeans.

    Starting from a single cluster holding every sample, each round tries a
    2-means split of every current cluster and keeps the split that gives
    the lowest total squared error, until ``n_clusters`` clusters exist.

    Attributes set by ``fit``: ``centroids``, ``clusterAssment`` (column 0:
    cluster index, column 1: squared distance to the centroid), ``labels``
    and ``sse``. When the data cannot be split any further (for example
    fewer distinct samples than ``n_clusters``), ``centroids`` holds fewer
    than ``n_clusters`` rows.
    """

    # Consecutive rounds allowed to fail (random 2-means initialisation can
    # leave one child empty) before fit() gives up splitting.
    _MAX_SPLIT_RETRIES = 10

    def __init__(self, n_clusters=5):
        self.n_clusters = n_clusters
        self.centroids = None
        self.clusterAssment = None
        self.labels = None
        self.sse = None

    # Euclidean distance between two points.
    def _distEclud(self, vecA, vecB):
        return np.linalg.norm(vecA - vecB)

    def fit(self, X):
        """Cluster the rows of ``X`` and populate the result attributes."""
        X = np.asarray(X)  # accept array-likes, consistent with predict()
        m = X.shape[0]
        self.clusterAssment = np.zeros((m, 2))
        centroid0 = np.mean(X, axis=0).tolist()
        centList = [centroid0]
        # Initial squared error of every sample against the global mean.
        self.clusterAssment[:, 1] = np.sum((X - np.asarray(centroid0)) ** 2, axis=1)

        retriesLeft = self._MAX_SPLIT_RETRIES
        while len(centList) < self.n_clusters:
            lowestSSE = np.inf
            bestCentToSplit = None
            anySplittable = False
            for i in range(len(centList)):  # try splitting each cluster
                inCluster = self.clusterAssment[:, 0] == i
                ptsInCurrCluster = X[inCluster, :]
                if ptsInCurrCluster.shape[0] < 2:
                    continue  # nothing to bisect
                anySplittable = True
                clf = KMeans(n_clusters=2)
                clf.fit(ptsInCurrCluster)
                centroidMat, splitClustAss = clf.centroids, clf.clusterAssment
                # Skip degenerate splits: an empty child or a NaN centroid.
                if np.isnan(centroidMat).any() or np.unique(splitClustAss[:, 0]).size < 2:
                    continue
                sseSplit = splitClustAss[:, 1].sum()
                sseNotSplit = self.clusterAssment[~inCluster, 1].sum()
                if (sseSplit + sseNotSplit) < lowestSSE:
                    bestCentToSplit = i
                    bestNewCents = centroidMat
                    bestClustAss = splitClustAss.copy()
                    lowestSSE = sseSplit + sseNotSplit

            if bestCentToSplit is None:
                # No usable split this round: stop when nothing can be split
                # or the retry budget is spent, otherwise draw again.
                if not anySplittable or retriesLeft == 0:
                    break
                retriesLeft -= 1
                continue
            retriesLeft = self._MAX_SPLIT_RETRIES

            # One child keeps the parent's index, the other becomes a new
            # cluster appended at len(centList). Both masks are taken before
            # writing so the relabelling cannot collide.
            firstChild = bestClustAss[:, 0] == 0
            secondChild = bestClustAss[:, 0] == 1
            bestClustAss[secondChild, 0] = len(centList)
            bestClustAss[firstChild, 0] = bestCentToSplit
            parentRows = self.clusterAssment[:, 0] == bestCentToSplit
            centList[bestCentToSplit] = bestNewCents[0, :].tolist()
            centList.append(bestNewCents[1, :].tolist())
            self.clusterAssment[parentRows, :] = bestClustAss

        self.labels = self.clusterAssment[:, 0]
        self.sse = self.clusterAssment[:, 1].sum()
        self.centroids = np.asarray(centList)

    def predict(self, X):
        """Return, for every row of ``X``, the index of the nearest centroid."""
        if not isinstance(X, np.ndarray):
            try:
                X = np.asarray(X)
            except Exception as err:
                raise TypeError("numpy.ndarray required for X") from err

        m = X.shape[0]  # number of samples
        preds = np.empty((m,))
        # Iterate over the centroids actually produced by fit(), which can
        # be fewer than n_clusters when the data could not be split further.
        n_centroids = len(self.centroids)
        for i in range(m):
            minDist = np.inf
            for j in range(n_centroids):
                distJI = self._distEclud(self.centroids[j, :], X[i, :])
                if distJI < minDist:
                    minDist = distJI
                    preds[i] = j
        return preds
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/random_seed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
def setup_seed(seed):
    """Seed every random number generator used here for repeatable runs.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and switches cuDNN to deterministic mode with auto-tuning disabled.

    Args:
        seed: integer seed applied to all generators.
    """
    # Local import: this module's header only imports torch and numpy.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The benchmark auto-tuner may pick different kernels from run to run,
    # which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
10 |
11 |
12 |
--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib.font_manager import FontProperties as fp # 1、引入FontProperties
3 | import math
4 |
def result_visualization(loss_list: list,
                         correct_on_test: list,
                         correct_on_train: list,
                         test_interval: int,
                         d_model: int,
                         q: int,
                         v: int,
                         h: int,
                         N: int,
                         dropout: float,
                         DATA_LEN: int,
                         BATCH_SIZE: int,
                         time_cost: float,
                         EPOCH: int,
                         draw_key: int,
                         reslut_figure_path: str,
                         optimizer_name: str,
                         file_name: str,
                         LR: float,
                         pe: bool,
                         mask: bool):
    """Plot the loss and accuracy curves of a training run and print a summary.

    Draws ``loss_list`` in the top panel and the test/train accuracy lists in
    the bottom panel, writes a text summary between them, saves the figure
    under ``reslut_figure_path`` when ``EPOCH >= draw_key``, shows it, and
    prints the same summary to stdout.

    Args:
        loss_list: recorded loss values. NOTE(review): the epoch of the
            minimum is derived by dividing the index by
            ``ceil(DATA_LEN / BATCH_SIZE)``, so this looks like one entry
            per mini-batch - confirm with the caller.
        correct_on_test: test accuracy, one entry per evaluation.
        correct_on_train: train accuracy, one entry per evaluation.
        test_interval: number of epochs between two evaluations.
        d_model, q, v, h, N, dropout: hyper-parameters, only echoed in the
            summary text and the file name.
        DATA_LEN: data-set size used for the batches-per-epoch computation.
        BATCH_SIZE: mini-batch size.
        time_cost: elapsed time, reported as minutes.
        EPOCH: number of trained epochs.
        draw_key: the figure is saved only when ``EPOCH >= draw_key``.
        reslut_figure_path: directory receiving the saved figure.
        optimizer_name, file_name, LR, pe, mask: echoed in the file name.

    Raises:
        ValueError: if any of the three lists is empty (``min``/``max``).
    """
    # Local import: this module's header does not import os.
    import os

    my_font = fp(fname=r"font/simsun.ttc")  # font used for the CJK summary text

    # The 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6 and
    # the old name was later removed; use whichever this version provides.
    plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

    # Summary statistics, computed once and reused for figure and stdout.
    min_loss = min(loss_list)
    min_loss_epoch = math.ceil((loss_list.index(min_loss) + 1) / math.ceil((DATA_LEN / BATCH_SIZE)))
    best_test = max(correct_on_test)
    best_train = max(correct_on_train)
    best_test_epoch = (correct_on_test.index(best_test) + 1) * test_interval
    minutes = round(time_cost, 2)

    fig = plt.figure()
    ax1 = fig.add_subplot(311)  # top panel: loss
    ax2 = fig.add_subplot(313)  # bottom panel: accuracy

    ax1.plot(loss_list)
    ax2.plot(correct_on_test, color='red', label='on Test Dataset')
    ax2.plot(correct_on_train, color='blue', label='on Train Dataset')

    ax1.set_xlabel('epoch')
    ax1.set_ylabel('loss')
    ax2.set_xlabel(f'epoch/{test_interval}')
    ax2.set_ylabel('correct')
    ax1.set_title('LOSS')
    ax2.set_title('CORRECT')

    ax2.legend(loc='best')

    # Summary text between the two panels. The keyword must be the
    # lower-case 'fontproperties'; mixed-case property names are rejected
    # by current Matplotlib releases.
    fig.text(x=0.13, y=0.4, s=f'最小loss:{min_loss}' ' '
                              f'最小loss对应的epoch数:{min_loss_epoch}' ' '
                              f'最后一轮loss:{loss_list[-1]}' '\n'
                              f'最大correct:测试集:{best_test}% 训练集:{best_train}%' ' '
                              f'最大correct对应的已训练epoch数:{best_test_epoch}' ' '
                              f'最后一轮correct:{correct_on_test[-1]}%' '\n'
                              f'd_model={d_model} q={q} v={v} h={h} N={N} drop_out={dropout}' '\n'
                              f'共耗时{minutes}分钟', fontproperties=my_font)

    # Short trial runs (EPOCH below draw_key) are not saved.
    if EPOCH >= draw_key:
        if reslut_figure_path:
            os.makedirs(reslut_figure_path, exist_ok=True)
        plt.savefig(
            f'{reslut_figure_path}/{file_name} {best_test}% {optimizer_name} epoch={EPOCH} batch={BATCH_SIZE} lr={LR} pe={pe} mask={mask} [{d_model},{q},{v},{h},{N},{dropout}].png')

    plt.show()

    print('正确率列表', correct_on_test)

    print(f'最小loss:{min_loss}\r\n'
          f'最小loss对应的epoch数:{min_loss_epoch}\r\n'
          f'最后一轮loss:{loss_list[-1]}\r\n')

    # The label previously read '最correct...' (missing '大'); it now
    # matches the wording used in the figure text.
    print(f'最大correct:测试集:{best_test}\t 训练集:{best_train}\r\n'
          f'最大correct对应的已训练epoch数:{best_test_epoch}\r\n'
          f'最后一轮correct:{correct_on_test[-1]}')

    print(f'共耗时{minutes}分钟')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gated-Transformer-on-MTS
2 | 基于Pytorch,使用改良的Transformer模型应用于多维时间序列的分类任务上
3 |
4 | ## 实验结果
5 | 对比模型选择 Fully Convolutional Networks (FCN) and Residual Networks (ResNet)
6 |
7 |
8 | DataSet|MLP|FCN|ResNet|Encoder|MCNN|t-LeNet|MCDCNN|Time-CNN|TWIESN|Gated Transformer|
9 | -------|---|---|------|-------|----|-------|------|--------|------|-----------------|
10 | ArabicDigits|96.9|99.4|**99.6**|98.1|10.0|10.0|95.9|95.8|85.3|98.8|
11 | AUSLAN|93.3|**97.5**|97.4|93.8|1.1|1.1|85.4|72.6|72.4|**97.5**|
12 | CharacterTrajectories|96.9|**99.0**|**99.0**|97.1|5.4|6.7|93.8|96.0|92.0|97.0|
13 | CMUsubject16|60.0|**100**|99.7|98.3|53.1|51.0|51.4|97.6|89.3|**100**|
14 | ECG|74.8|87.2|86.7|87.2|67.0|67.0|50.0|84.1|73.7|**91.0**|
15 | JapaneseVowels|97.6|**99.3**|99.2|97.6|9.2|23.8|94.4|95.6|96.5|98.7|
16 | Libras|78.0|**96.4**|95.4|78.3|6.7|6.7|65.1|63.7|79.4|88.9|
17 | UWave|90.1|**93.4**|92.6|90.8|12.5|12.5|84.5|8.9|75.4|91.0|
18 | KickvsPunch|61.0|54.0|51.0|61.0|54.0|50.0|56.0|62.0|67.0|**90.0**|
19 | NetFlow|55.0|89.1|62.7|77.7|77.9|72.3|63.0|89.0|94.5|**100**|
20 | PEMS|-|-|-|-|-|-|-|-|-|93.6|
21 | Wafer|89.4|98.2|98.9|98.6|89.4|89.4|65.8|94.8|94.9|**99.1**|
22 | WalkvsRun|70.0|**100**|**100**|**100**|75.0|60.0|45.0|**100.0**|94.4|**100**|
23 |
24 | ## 实验环境
25 | 环境|描述|
26 | ---|---------|
27 | 语言|Python3.7|
28 | 框架|Pytorch1.6|
29 | IDE|Pycharm and Colab|
30 | 设备|CPU and GPU|
31 |
32 | ## 数据集
33 | 多元时间序列数据集, 文件为.mat格式,训练集与测试集在一个文件中,且预先定义为了测试集数据,测试集标签,训练集数据与训练集标签。
34 | 数据集下载使用百度云盘,连接如下:
35 | 链接:https://pan.baidu.com/s/1u2HN6tfygcQvzuEK5XBa2A
36 | 提取码:dxq6
37 | Google drive link:https://drive.google.com/drive/folders/1QFadJOmbOLWMjLrcebZQR_w2fBX7x0Vm?usp=share_link
38 |
39 | UEA and UCR dataset:http://www.timeseriesclassification.com/index.php
40 |
41 | ---
42 |
43 | 数据集维度描述
44 | DataSet|Number of Classes|Size of training Set|Size of testing Set|Max Time series Length|Channel|
45 | -------|-----------------|--------------------|-------------------|----------------------|-------|
46 | ArabicDigits|10|6600|2200|93|13|
47 | AUSLAN|95|1140|1425|136|22|
48 | CharacterTrajectories|20|300|2558|205|3|
49 | CMUsubject16|2|29|29|580|62|
50 | ECG|2|100|100|152|2|
51 | JapaneseVowels|9|270|370|29|12|
52 | Libras|15|180|180|45|2|
53 | UWave|8|200|4278|315|3|
54 | KickvsPunch|2|16|10|841|62|
55 | NetFlow|2|803|534|997|4|
56 | PEMS|7|267|173|144|963|
57 | Wafer|2|298|896|198|6|
58 | WalkvsRun|2|28|16|1918|62|
59 |
60 | ## 数据预处理
61 | 详细数据集处理过程参看 dataset_process.py文件。
62 | - 创建torch.utils.data.Dataset对象,在类中对数据集进行处理,其成员变量定义的有训练集数据,训练集标签,测试集数据,测试集标签等。创建torch.utils.data.DataLoader对象,生成训练过程中的mini-batch与数据集的随机shuffle
63 | - 数据集中不同样本的Time series Length不同,处理时使用所有样本中(测试集与训练集中)**最长**的时间步作为Time series Length,使用**0**进行填充。
64 | - 数据集处理过程中保存未添加Padding的训练集数据与测试集数据,还有测试集中最长时间步的样本列表以供探索模型使用。
65 | - NetFlow数据集中标签为**1和13**,在使用此数据集时要对返回的标签值进行处理。
66 |
67 | ## 模型描述
68 |
69 |
70 |
71 | - 仅使用Encoder:由于是分类任务,模型删去传统Transformer中的decoder,**仅使用Encoder**进行分类
72 | - Two Tower:在不同的Step之间或者不同Channel之间,显然存在着诸多联系,传统Transformer使用Attention机制用来关注不同的step或channel之间的相关程度,但仅选择一个进行计算。不同于CNN模型处理时间序列,它可以使用二维卷积核同时关注step-wise和channel-wise,在这里我们使用**双塔**模型,即同时计算step-wise Attention和channel-wise Attention。
73 | - Gate机制:对于不同的数据集,不同的Attention机制有好有坏,对于双塔的特征提取的结果,简单的方法,是对两个塔的输出进行简单的拼接,不过在这里,我们使用模型学习两个权重值,为每个塔的输出进行权重的分配,公式如下。
74 | `h = W · Concat(C, S) + b`
75 | `g1, g2 = Softmax(h)`
76 | `y = Concat(C · g1, S · g2)`
77 | - 在step-wise,模型如传统Transformer一样,添加位置编码与mask机制,而在channel-wise,模型舍弃位置编码与mask,因为对于没有时间特性的channel之间,这两个机制没有实际的意义。
78 |
79 | ## 超参描述
80 | 超参|描述|
81 | ----|---|
82 | d_model|模型处理的为时间序列而非自然语言,所以省略了NLP中对词语的编码,仅使用一个线性层映射成d_model维的稠密向量,此外,d_model保证了在每个模块衔接的地方的维度相同|
83 | d_hidden|Position-wise FeedForward 中隐藏层的维度|
84 | d_input|时间序列长度,其实是一个数据集中最长时间步的维度 **固定**的,直接由数据集预处理决定|
85 | d_channel|多元时间序列的时间通道数,即是几维的时间序列 **固定**的,直接由数据集预处理决定|
86 | d_output|分类类别数 **固定**的,直接由数据集预处理决定|
87 | q,v|Multi-Head Attention中线性层映射维度|
88 | h|Multi-Head Attention中头的数量|
89 | N|Encoder栈中Encoder的数量|
90 | dropout|随机失活|
91 | EPOCH|训练迭代次数|
92 | BATCH_SIZE|mini-batch size|
93 | LR|学习率 定义为1e-4|
94 | optimizer_name|优化器选择 建议**Adagrad**和Adam|
95 |
96 | ## 文件描述
97 | 文件名称|描述|
98 | -------|----|
99 | dataset_process|数据集处理|
100 | font|存储字体,用于结果图中的文字|
101 | gather_figure|聚类结果图|
102 | heatmap_figure_in_test|测试模型时绘制的score矩阵的热力图|
103 | module|模型的各个模块|
104 | mytest|各种测试代码|
105 | result_figure|准确率结果图|
106 | saved_model|保存的pkl文件|
107 | utils|工具类文件|
108 | run.py|训练模型|
109 | run_with_saved_model.py|使用训练好的模型(保存为pkl文件)测试结果|
110 |
111 | ## utils工具描述
112 | 简单介绍几个
113 | - random_seed:用于设置**随机种子**,使每一次的实验结果可复现。
114 | - heatMap.py:用于绘制双塔的score矩阵的**热力图**,用来分析channel与channel之间或者step与step之间的相关程度,用于比较的还有**DTW**矩阵和欧氏距离矩阵,用来分析决定权重分配的因素。
115 | - draw_line:用于绘制折线图,一般需要根据需要自定义新的函数进行绘制。
116 | - visualization:用于绘制训练模型的loss变化曲线和accuracy变化曲线,判断是否收敛与过拟合。
117 | - TSNE:**降维聚类算法**并绘制聚类图,用来评估模型特征提取的效果或者时间序列之间的相似性。
118 |
119 | ## Tips
120 | - .pkl文件需要先训练,并在训练结束时进行保存(设置参数为True),由于github对文件大小的限制,上传文件中不包含训练好的.pkl文件。
121 | - .pkl文件使用pycharm上的1.6版本的pytorch和colab上1.7的pytorch保存,若想load模型直接进行测试,需要测试使用的pytorch版本尽可能高于等于**1.6版本**。
122 | - 根目录文件如saved_model,result_figure为保存的默认路径,请勿删除或者修改名称,除非直接在源代码中对路径进行修改。
123 | - 请使用百度云盘提供的数据集,不同的MTS数据集文件格式不同,本数据集处理的是.mat文件。
124 | - utils中的工具类,在绘制彩色曲线和聚类图时,对于图中颜色的划分,由于需求不能泛化,请在函数中自行编写代码定义。
125 | - save model保存的.pkl文件在迭代过程中不断更新,在最后保存最高准确率的模型并命名,命名格式请勿修改,因为在run_with_saved_model.py中,对文件命名中的信息会加以利用,若干绘图结果的命名也会参考其中的信息。
126 | - 优先选择GPU,没有则使用CPU。
127 |
128 | ## 参考
129 | ```
130 | [Wang et al., 2017] Z. Wang, W. Yan, and T. Oates. Time series classification from scratch with deep neural networks:A strong baseline. In 2017 International Joint Conference on Neural Networks (IJCNN), pages 1578–1585, 2017.
131 | ```
132 |
133 | ## 本人学识浅薄,代码和文字若有不当之处欢迎批评与指正!
134 | ## 联系方式:masiyuan007@qq.com
135 |
--------------------------------------------------------------------------------