├── Gated Transformer 论文IJCAI版 ├── .idea │ ├── .gitignore │ ├── ADL_Transformer.iml │ ├── inspectionProfiles │ │ └── profiles_settings.xml │ ├── misc.xml │ ├── modules.xml │ └── vcs.xml ├── Gated Transformer 架构图.png ├── dataset_process │ ├── __pycache__ │ │ └── dataset_process.cpython-37.pyc │ └── dataset_process.py ├── font │ └── simsun.ttc ├── gather_figure │ ├── AUSLAN │ │ ├── AUSLAN channel_gather.jpg │ │ └── AUSLAN input_gather.jpg │ ├── ECG │ │ ├── ECG all_sample_gather.jpg │ │ ├── ECG channel_gather.jpg │ │ └── ECG input_gather.jpg │ └── JapaneseVowels │ │ ├── JapaneseVowels Sample_gather.jpg │ │ ├── JapaneseVowels channel_gather.jpg │ │ └── JapaneseVowels input_gather.jpg ├── heatmap_figure_in_test │ └── JapaneseVowels │ │ └── JapaneseVowels accuracy=98.11 0.jpg ├── module │ ├── __pycache__ │ │ ├── encoder.cpython-37.pyc │ │ ├── feedForward.cpython-37.pyc │ │ ├── loss.cpython-37.pyc │ │ ├── multiHeadAttention.cpython-37.pyc │ │ └── transformer.cpython-37.pyc │ ├── encoder.py │ ├── feedForward.py │ ├── loss.py │ ├── multiHeadAttention.py │ └── transformer.py ├── mytest │ ├── DTW_test.py │ ├── HeatMap.py │ ├── HeatMap_DTW.py │ ├── __pycache__ │ │ └── HeatMap_DTW.cpython-37.pyc │ └── kmeans_test.py ├── result_figure │ ├── CharacterTrajectories 96.6% Adam epoch=100 batch=4 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png │ ├── ECG 86.0% Adam epoch=10 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png │ ├── ECG 86.0% Adam epoch=20 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png │ ├── ECG 87.0% Adam epoch=50 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png │ └── JapaneseVowels 98.11% Adagrad epoch=100 batch=3 lr=0.0001 pe=True mask=True [512,8,8,8,8,0.2].png ├── run.py ├── run_with_saved_model.py └── utils │ ├── TSNE.py │ ├── __pycache__ │ ├── TSNE.cpython-37.pyc │ ├── colorful_line.cpython-37.pyc │ ├── draw_line.cpython-37.pyc │ ├── heatMap.cpython-37.pyc │ ├── random_seed.cpython-37.pyc │ └── visualization.cpython-37.pyc │ ├── 
colorful_line.py │ ├── draw_line.py │ ├── heatMap.py │ ├── kmeans.py │ ├── random_seed.py │ └── visualization.py └── README.md /Gated Transformer 论文IJCAI版/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/.idea/ADL_Transformer.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/Gated Transformer 架构图.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/Gated Transformer 架构图.png -------------------------------------------------------------------------------- /Gated Transformer 
论文IJCAI版/dataset_process/__pycache__/dataset_process.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/dataset_process/__pycache__/dataset_process.cpython-37.pyc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/dataset_process/dataset_process.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | 4 | from scipy.io import loadmat 5 | 6 | 7 | class MyDataset(Dataset): 8 | def __init__(self, 9 | path: str, 10 | dataset: str): 11 | """ 12 | Dataset object for the training and test sets 13 | :param path: path to the dataset 14 | :param dataset: selects whether to return the training set or the test set 15 | """ 16 | super(MyDataset, self).__init__() 17 | self.dataset = dataset # choose between the test set and the training set 18 | self.train_len, \ 19 | self.test_len, \ 20 | self.input_len, \ 21 | self.channel_len, \ 22 | self.output_len, \ 23 | self.train_dataset, \ 24 | self.train_label, \ 25 | self.test_dataset, \ 26 | self.test_label, \ 27 | self.max_length_sample_inTest, \ 28 | self.train_dataset_with_no_paddding = self.pre_option(path) 29 | 30 | def __getitem__(self, index): 31 | if self.dataset == 'train': 32 | return self.train_dataset[index], self.train_label[index] - 1 33 | elif self.dataset == 'test': 34 | return self.test_dataset[index], self.test_label[index] - 1 35 | 36 | def __len__(self): 37 | if self.dataset == 'train': 38 | return self.train_len 39 | elif self.dataset == 'test': 40 | return self.test_len 41 | 42 | # data preprocessing 43 | def pre_option(self, path: str): 44 | """ 45 | Data preprocessing. Samples differ in their number of time steps, so the longest sequence length is used as the time-step dimension and shorter samples are padded with 0 46 | :param path: path to the dataset 47 | :return: number of training samples, number of test samples, time-step dimension, number of channels, number of classes, training data, training labels, test data, test labels, list of the test samples with the most time steps, training data without padding 48 | """ 49 | m = loadmat(path) 50 | 51 | # m is a dict with 4 keys; the last key-value pair holds the data 52 | x1, x2, x3, x4 = m 53 | data = m[x4] 
54 | 55 | data00 = data[0][0] 56 | # print('data00.shape', data00.shape) # (); data00 is the level at which the actual data is reached 57 | 58 | index_train = str(data.dtype).find('train\'') 59 | index_trainlabels = str(data.dtype).find('trainlabels') 60 | index_test = str(data.dtype).find('test\'') 61 | index_testlabels = str(data.dtype).find('testlabels') 62 | list = [index_test, index_train, index_testlabels, index_trainlabels] 63 | list = sorted(list) 64 | index_train = list.index(index_train) 65 | index_trainlabels = list.index(index_trainlabels) 66 | index_test = list.index(index_test) 67 | index_testlabels = list.index(index_testlabels) 68 | 69 | # [('trainlabels', 'O'), ('train', 'O'), ('testlabels', 'O'), ('test', 'O')] 'O' means the dtype is numpy.object 70 | train_label = data00[index_trainlabels] 71 | train_data = data00[index_train] 72 | test_label = data00[index_testlabels] 73 | test_data = data00[index_test] 74 | 75 | train_label = train_label.squeeze() 76 | train_data = train_data.squeeze() 77 | test_label = test_label.squeeze() 78 | test_data = test_data.squeeze() 79 | 80 | train_len = train_data.shape[0] 81 | test_len = test_data.shape[0] 82 | output_len = len(tuple(set(train_label))) 83 | 84 | # maximum number of time steps 85 | max_lenth = 0 # 93 86 | for item in train_data: 87 | item = torch.as_tensor(item).float() 88 | if item.shape[1] > max_lenth: 89 | max_lenth = item.shape[1] 90 | # max_length_index = train_data.tolist().index(item.tolist()) 91 | 92 | for item in test_data: 93 | item = torch.as_tensor(item).float() 94 | if item.shape[1] > max_lenth: 95 | max_lenth = item.shape[1] 96 | 97 | # padding: pad with 0 98 | # train_data and test_data have dtype numpy.object, so the numpy.ndarray items inside them cannot be processed directly 99 | train_dataset_with_no_paddding = [] 100 | test_dataset_with_no_paddding = [] 101 | train_dataset = [] 102 | test_dataset = [] 103 | max_length_sample_inTest = [] 104 | for x1 in train_data: 105 | train_dataset_with_no_paddding.append(x1.transpose(-1, -2).tolist()) 106 | x1 = torch.as_tensor(x1).float() 107 | if x1.shape[1] != max_lenth: 
108 | padding = torch.zeros(x1.shape[0], max_lenth - x1.shape[1]) 109 | x1 = torch.cat((x1, padding), dim=1) 110 | train_dataset.append(x1) 111 | 112 | for index, x2 in enumerate(test_data): 113 | test_dataset_with_no_paddding.append(x2.transpose(-1, -2).tolist()) 114 | x2 = torch.as_tensor(x2).float() 115 | if x2.shape[1] != max_lenth: 116 | padding = torch.zeros(x2.shape[0], max_lenth - x2.shape[1]) 117 | x2 = torch.cat((x2, padding), dim=1) 118 | else: 119 | max_length_sample_inTest.append(x2.transpose(-1, -2)) 120 | test_dataset.append(x2) 121 | 122 | # final shape: [number of samples, maximum number of time steps, time-series dimension] 123 | # train_dataset_with_no_paddding = torch.stack(train_dataset_with_no_paddding, dim=0).permute(0, 2, 1) 124 | # test_dataset_with_no_paddding = torch.stack(test_dataset_with_no_paddding, dim=0).permute(0, 2, 1) 125 | train_dataset = torch.stack(train_dataset, dim=0).permute(0, 2, 1) 126 | test_dataset = torch.stack(test_dataset, dim=0).permute(0, 2, 1) 127 | train_label = torch.Tensor(train_label) 128 | test_label = torch.Tensor(test_label) 129 | channel = test_dataset[0].shape[-1] 130 | input = test_dataset[0].shape[-2] 131 | 132 | return train_len, test_len, input, channel, output_len, train_dataset, train_label, test_dataset, test_label, max_length_sample_inTest, train_dataset_with_no_paddding -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/font/simsun.ttc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/font/simsun.ttc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/AUSLAN/AUSLAN channel_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 
论文IJCAI版/gather_figure/AUSLAN/AUSLAN channel_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/AUSLAN/AUSLAN input_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/AUSLAN/AUSLAN input_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG all_sample_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG all_sample_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG channel_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG channel_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG input_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/ECG/ECG input_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels Sample_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated 
Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels Sample_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels channel_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels channel_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels input_gather.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/gather_figure/JapaneseVowels/JapaneseVowels input_gather.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/heatmap_figure_in_test/JapaneseVowels/JapaneseVowels accuracy=98.11 0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/heatmap_figure_in_test/JapaneseVowels/JapaneseVowels accuracy=98.11 0.jpg -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/__pycache__/encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/encoder.cpython-37.pyc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/__pycache__/feedForward.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/feedForward.cpython-37.pyc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/__pycache__/multiHeadAttention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/multiHeadAttention.cpython-37.pyc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/module/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/encoder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | import torch 3 | 4 | from module.feedForward import FeedForward 5 | from module.multiHeadAttention import MultiHeadAttention 6 | 7 | class Encoder(Module): 8 | def __init__(self, 9 | d_model: int, 10 | d_hidden: int, 11 | q: int, 12 | v: int, 13 | h: int, 14 | device: str, 15 | mask: bool = False, 
16 | dropout: float = 0.1): 17 | super(Encoder, self).__init__() 18 | 19 | self.MHA = MultiHeadAttention(d_model=d_model, q=q, v=v, h=h, mask=mask, device=device, dropout=dropout) 20 | self.feedforward = FeedForward(d_model=d_model, d_hidden=d_hidden) 21 | self.dropout = torch.nn.Dropout(p=dropout) 22 | self.layerNormal_1 = torch.nn.LayerNorm(d_model) 23 | self.layerNormal_2 = torch.nn.LayerNorm(d_model) 24 | 25 | def forward(self, x, stage): 26 | 27 | residual = x 28 | x, score = self.MHA(x, stage) 29 | x = self.dropout(x) 30 | x = self.layerNormal_1(x + residual) 31 | 32 | residual = x 33 | x = self.feedforward(x) 34 | x = self.dropout(x) 35 | x = self.layerNormal_2(x + residual) 36 | 37 | return x, score 38 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/feedForward.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class FeedForward(Module): 6 | def __init__(self, 7 | d_model: int, 8 | d_hidden: int = 512): 9 | super(FeedForward, self).__init__() 10 | 11 | self.linear_1 = torch.nn.Linear(d_model, d_hidden) 12 | self.linear_2 = torch.nn.Linear(d_hidden, d_model) 13 | 14 | def forward(self, x): 15 | 16 | x = self.linear_1(x) 17 | x = F.relu(x) 18 | x = self.linear_2(x) 19 | 20 | return x -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/loss.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | import torch 3 | from torch.nn import CrossEntropyLoss 4 | 5 | class Myloss(Module): 6 | def __init__(self): 7 | super(Myloss, self).__init__() 8 | self.loss_function = CrossEntropyLoss() 9 | 10 | def forward(self, y_pre, y_true): 11 | y_true = y_true.long() 12 | loss = self.loss_function(y_pre, y_true) 13 | 14 | return loss 
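As a rough illustration of what `Myloss` computes, the sketch below reimplements mean cross-entropy in plain Python. It assumes `CrossEntropyLoss` is used with its default mean reduction, as in the constructor call above; the function name `cross_entropy` and the list-based inputs are illustrative and not part of the repository.

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch (illustrative helper, not repo code).

    logits: list of per-sample score lists, one raw score per class
    labels: list of 0-based class indices; floats are cast to int,
            mirroring the y_true.long() cast in Myloss.forward
    """
    total = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # negative log-probability of the true class
        total += log_sum_exp - row[int(label)]
    return total / len(logits)
```

For example, `cross_entropy([[0.0, 0.0]], [1.0])` equals `math.log(2)`, since both classes are equally likely; the float label stands in for the `label - 1` values that `MyDataset.__getitem__` returns.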
-------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/multiHeadAttention.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | import torch 3 | import math 4 | import torch.nn.functional as F 5 | 6 | 7 | class MultiHeadAttention(Module): 8 | def __init__(self, 9 | d_model: int, 10 | q: int, 11 | v: int, 12 | h: int, 13 | device: str, 14 | mask: bool=False, 15 | dropout: float = 0.1): 16 | super(MultiHeadAttention, self).__init__() 17 | 18 | self.W_q = torch.nn.Linear(d_model, q * h) 19 | self.W_k = torch.nn.Linear(d_model, q * h) 20 | self.W_v = torch.nn.Linear(d_model, v * h) 21 | 22 | self.W_o = torch.nn.Linear(v * h, d_model) 23 | 24 | self.device = device 25 | self._h = h 26 | self._q = q 27 | 28 | self.mask = mask 29 | self.dropout = torch.nn.Dropout(p=dropout) 30 | self.score = None 31 | 32 | def forward(self, x, stage): 33 | Q = torch.cat(self.W_q(x).chunk(self._h, dim=-1), dim=0) 34 | K = torch.cat(self.W_k(x).chunk(self._h, dim=-1), dim=0) 35 | V = torch.cat(self.W_v(x).chunk(self._h, dim=-1), dim=0) 36 | 37 | score = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self._q) 38 | self.score = score 39 | 40 | if self.mask and stage == 'train': 41 | mask = torch.ones_like(score[0]) 42 | mask = torch.tril(mask, diagonal=0) 43 | score = torch.where(mask > 0, score, torch.Tensor([-2**32+1]).expand_as(score[0]).to(self.device)) 44 | 45 | score = F.softmax(score, dim=-1) 46 | 47 | attention = torch.matmul(score, V) 48 | 49 | attention_heads = torch.cat(attention.chunk(self._h, dim=0), dim=-1) 50 | 51 | self_attention = self.W_o(attention_heads) 52 | 53 | return self_attention, self.score -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/module/transformer.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 
import torch 3 | from torch.nn import ModuleList 4 | from module.encoder import Encoder 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | class Transformer(Module): 10 | def __init__(self, 11 | d_model: int, 12 | d_input: int, 13 | d_channel: int, 14 | d_output: int, 15 | d_hidden: int, 16 | q: int, 17 | v: int, 18 | h: int, 19 | N: int, 20 | device: str, 21 | dropout: float = 0.1, 22 | pe: bool = False, 23 | mask: bool = False): 24 | super(Transformer, self).__init__() 25 | 26 | self.encoder_list_1 = ModuleList([Encoder(d_model=d_model, 27 | d_hidden=d_hidden, 28 | q=q, 29 | v=v, 30 | h=h, 31 | mask=mask, 32 | dropout=dropout, 33 | device=device) for _ in range(N)]) 34 | 35 | self.encoder_list_2 = ModuleList([Encoder(d_model=d_model, 36 | d_hidden=d_hidden, 37 | q=q, 38 | v=v, 39 | h=h, 40 | dropout=dropout, 41 | device=device) for _ in range(N)]) 42 | 43 | self.embedding_channel = torch.nn.Linear(d_channel, d_model) 44 | self.embedding_input = torch.nn.Linear(d_input, d_model) 45 | 46 | self.gate = torch.nn.Linear(d_model * d_input + d_model * d_channel, 2) 47 | self.output_linear = torch.nn.Linear(d_model * d_input + d_model * d_channel, d_output) 48 | 49 | self.pe = pe 50 | self._d_input = d_input 51 | self._d_model = d_model 52 | 53 | def forward(self, x, stage): 54 | """ 55 | Forward pass 56 | :param x: input 57 | :param stage: indicates whether this is training on the training set or testing on the test set; the mask is never applied during testing 58 | :return: output, the 2-D vector after gating, the score matrix of the step-wise encoder, the score matrix of the channel-wise encoder, the 3-D tensor after the step-wise embedding, the 3-D tensor after the channel-wise embedding, gate 59 | """ 60 | # step-wise 61 | # the score matrix is over input (time steps); mask and pe are applied by default 62 | encoding_1 = self.embedding_channel(x) 63 | input_to_gather = encoding_1 64 | 65 | if self.pe: 66 | pe = torch.ones_like(encoding_1[0]) 67 | position = torch.arange(0, self._d_input).unsqueeze(-1) 68 | temp = torch.Tensor(range(0, self._d_model, 2)) 69 | temp = temp * -(math.log(10000) / self._d_model) 70 | temp = torch.exp(temp).unsqueeze(0) 71 | temp = 
torch.matmul(position.float(), temp) # shape:[input, d_model/2] 72 | pe[:, 0::2] = torch.sin(temp) 73 | pe[:, 1::2] = torch.cos(temp) 74 | 75 | encoding_1 = encoding_1 + pe 76 | 77 | for encoder in self.encoder_list_1: 78 | encoding_1, score_input = encoder(encoding_1, stage) 79 | 80 | # channel-wise 81 | # the score matrix is over channels; no mask or pe by default 82 | encoding_2 = self.embedding_input(x.transpose(-1, -2)) 83 | channel_to_gather = encoding_2 84 | 85 | for encoder in self.encoder_list_2: 86 | encoding_2, score_channel = encoder(encoding_2, stage) 87 | 88 | # flatten 3-D to 2-D 89 | encoding_1 = encoding_1.reshape(encoding_1.shape[0], -1) 90 | encoding_2 = encoding_2.reshape(encoding_2.shape[0], -1) 91 | 92 | # gate 93 | gate = F.softmax(self.gate(torch.cat([encoding_1, encoding_2], dim=-1)), dim=-1) 94 | encoding = torch.cat([encoding_1 * gate[:, 0:1], encoding_2 * gate[:, 1:2]], dim=-1) 95 | 96 | # output 97 | output = self.output_linear(encoding) 98 | 99 | return output, encoding, score_input, score_channel, input_to_gather, channel_to_gather, gate 100 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/mytest/DTW_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dtw import dtw 4 | 5 | # x = np.array([2, 0, 1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1) 6 | # y = np.array([1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1) 7 | x = np.array([8, 9, 1, 9, 6, 1, 3, 5]).reshape(-1, 1) 8 | y = np.array([2, 5, 4, 6, 7, 8, 3, 7, 7, 2]).reshape(-1, 1) 9 | 10 | 11 | euclidean_norm = lambda x, y: np.abs(x - y) 12 | 13 | d, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=euclidean_norm) 14 | 15 | print('d ', d) 16 | print('cost_matrix \r\n', cost_matrix) 17 | print('acc_cost_matrix \r\n', acc_cost_matrix) 18 | print('path ', path) -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/mytest/HeatMap.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | sns.set() 4 | flights_long = sns.load_dataset("flights") 5 | flights = flights_long.pivot(index="month", columns="year", values="passengers") 6 | # draw an x-y-z heatmap, e.g. a year-month-sales heatmap 7 | f, ax = plt.subplots(figsize=(9, 6)) 8 | # draw the heatmap and write each value onto its cell 9 | sns.heatmap(flights, annot=True, fmt="d", ax=ax) 10 | # set the orientation of the axis tick labels 11 | label_y = ax.get_yticklabels() 12 | plt.setp(label_y, rotation=360, horizontalalignment='right') 13 | label_x = ax.get_xticklabels() 14 | plt.setp(label_x, rotation=45, horizontalalignment='right') 15 | plt.show() -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/mytest/HeatMap_DTW.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import numpy as np 4 | from dtw import dtw 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | import os 8 | 9 | plt.rcParams['font.sans-serif']=['SimHei'] # display Chinese labels 10 | plt.rcParams['axes.unicode_minus']=False # these two lines must be set manually 11 | 12 | from scipy.io import loadmat 13 | 14 | from torch.utils.data import Dataset 15 | import torch 16 | 17 | from scipy.io import loadmat 18 | 19 | 20 | class MyDataset(Dataset): 21 | def __init__(self, path, dataset): 22 | super(MyDataset, self).__init__() 23 | self.dataset = dataset 24 | self.train_len, \ 25 | self.test_len, \ 26 | self.input_len, \ 27 | self.channel_len, \ 28 | self.output_len, \ 29 | self.train_dataset, \ 30 | self.train_label, \ 31 | self.test_dataset, \ 32 | self.test_label = self.pre_option(path) 33 | 34 | def __getitem__(self, index): 35 | if self.dataset == 'train': 36 | return self.train_dataset[index], self.train_label[index] - 1 37 | elif self.dataset == 'test': 38 | return self.test_dataset[index], self.test_label[index] - 1 39 | 40 | def __len__(self): 41 | if self.dataset == 'train': 42 | 
return self.train_len 43 | elif self.dataset == 'test': 44 | return self.test_len 45 | 46 | def pre_option(self, path): 47 | m = loadmat(path) 48 | 49 | # m is a dict with 4 keys; the last key-value pair holds the data 50 | x1, x2, x3, x4 = m 51 | data = m[x4] 52 | 53 | data00 = data[0][0] 54 | # print('data00.shape', data00.shape) # (); data00 is the level at which the actual data is reached 55 | 56 | index_train = str(data.dtype).find('train\'') 57 | index_trainlabels = str(data.dtype).find('trainlabels') 58 | index_test = str(data.dtype).find('test\'') 59 | index_testlabels = str(data.dtype).find('testlabels') 60 | list = [index_test, index_train, index_testlabels, index_trainlabels] 61 | list = sorted(list) 62 | index_train = list.index(index_train) 63 | index_trainlabels = list.index(index_trainlabels) 64 | index_test = list.index(index_test) 65 | index_testlabels = list.index(index_testlabels) 66 | 67 | # [('trainlabels', 'O'), ('train', 'O'), ('testlabels', 'O'), ('test', 'O')] 'O' means the dtype is numpy.object 68 | train_label = data00[index_trainlabels] 69 | train_data = data00[index_train] 70 | test_label = data00[index_testlabels] 71 | test_data = data00[index_test] 72 | 73 | train_label = train_label.squeeze() 74 | train_data = train_data.squeeze() 75 | test_label = test_label.squeeze() 76 | test_data = test_data.squeeze() 77 | 78 | train_len = train_data.shape[0] 79 | test_len = test_data.shape[0] 80 | output_len = len(tuple(set(train_label))) 81 | 82 | # maximum number of time steps 83 | max_lenth = 0 # 93 84 | for item in train_data: 85 | item = torch.as_tensor(item).float() 86 | if item.shape[1] > max_lenth: 87 | max_lenth = item.shape[1] 88 | 89 | for item in test_data: 90 | item = torch.as_tensor(item).float() 91 | if item.shape[1] > max_lenth: 92 | max_lenth = item.shape[1] 93 | 94 | # train_data and test_data have dtype numpy.object, so the numpy.ndarray items inside them cannot be processed directly 95 | train_dataset = [] 96 | test_dataset = [] 97 | for x1 in train_data: 98 | x1 = torch.as_tensor(x1).float() 99 | if x1.shape[1] != max_lenth: 100 | padding = torch.zeros(x1.shape[0], max_lenth - 
x1.shape[1]) 101 | x1 = torch.cat((x1, padding), dim=1) 102 | train_dataset.append(x1) 103 | 104 | for x2 in test_data: 105 | x2 = torch.as_tensor(x2).float() 106 | if x2.shape[1] != max_lenth: 107 | padding = torch.zeros(x2.shape[0], max_lenth - x2.shape[1]) 108 | x2 = torch.cat((x2, padding), dim=1) 109 | test_dataset.append(x2) 110 | 111 | # final shape: [number of samples, maximum number of time steps, time-series dimension] 112 | train_dataset = torch.stack(train_dataset, dim=0).permute(0, 2, 1) 113 | test_dataset = torch.stack(test_dataset, dim=0).permute(0, 2, 1) 114 | train_label = torch.Tensor(train_label) 115 | test_label = torch.Tensor(test_label) 116 | channel = test_dataset[0].shape[-1] 117 | input = test_dataset[0].shape[-2] 118 | 119 | return train_len, test_len, input, channel, output_len, train_dataset, train_label, test_dataset, test_label 120 | 121 | def heatMap_channel(matrix, file_name, EPOCH): 122 | test_data = matrix[0].detach().numpy() 123 | euclidean_norm = lambda x, y: np.abs(x - y) 124 | matrix_0 = np.ones((test_data.shape[1], test_data.shape[1])) 125 | matrix_1 = np.ones((test_data.shape[1], test_data.shape[1])) # degree of difference 126 | for i in range(test_data.shape[1]): 127 | for j in range(test_data.shape[1]): 128 | x = test_data[:, i] 129 | y = test_data[:, j] 130 | d, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=euclidean_norm) 131 | matrix_0[i, j] = d 132 | matrix_1[i, j] = np.mean((x - y) ** 2) 133 | 134 | sns.set() 135 | # f, ax = plt.subplots(figsize=(9, 6)) 136 | sns.heatmap(matrix_0, annot=True) 137 | plt.title('CHANNEL DTW') 138 | if os.path.exists(f'../heatmap_figure/{file_name}'): 139 | plt.savefig(f'../heatmap_figure/{file_name}/channel DTW EPOCH:{EPOCH}.jpg') 140 | else: 141 | os.makedirs(f'../heatmap_figure/{file_name}') 142 | plt.savefig(f'../heatmap_figure/{file_name}/channel DTW EPOCH:{EPOCH}.jpg') 143 | sns.heatmap(matrix_1, annot=True) 144 | plt.title('CHANNEL difference') 145 | plt.savefig(f'../heatmap_figure/{file_name}/channel difference EPOCH:{EPOCH}.jpg') 146 | 147 | def 
heatMap_input(matrix, file_name, EPOCH): 148 | test_data = matrix[0].detach().numpy() 149 | euclidean_norm = lambda x, y: np.abs(x - y) 150 | matrix_0 = np.ones((test_data.shape[0], test_data.shape[0])) # DTW 151 | matrix_1 = np.ones((test_data.shape[0], test_data.shape[0])) # degree of difference 152 | for i in range(test_data.shape[0]): 153 | for j in range(test_data.shape[0]): 154 | x = test_data[i, :] 155 | y = test_data[j, :] 156 | d, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=euclidean_norm) 157 | matrix_0[i, j] = d 158 | matrix_1[i, j] = np.mean((x - y) ** 2) 159 | 160 | sns.set() 161 | # f, ax = plt.subplots(figsize=(9, 6)) 162 | sns.heatmap(matrix_0, annot=True) 163 | plt.title('INPUT DTW') 164 | if os.path.exists(f'../heatmap_figure/{file_name}'): 165 | plt.savefig(f'../heatmap_figure/{file_name}/input DTW EPOCH:{EPOCH}.jpg') 166 | else: 167 | os.makedirs(f'../heatmap_figure/{file_name}') 168 | plt.savefig(f'../heatmap_figure/{file_name}/input DTW EPOCH:{EPOCH}.jpg') 169 | sns.heatmap(matrix_1, annot=True) 170 | plt.title('INPUT difference') 171 | plt.savefig(f'../heatmap_figure/{file_name}/input difference EPOCH:{EPOCH}.jpg') 172 | 173 | def heatMap_score(matrix_input, matrix_channel, file_name, EPOCH): 174 | score_input = matrix_input[0].detach().numpy() 175 | score_channel = matrix_channel[0].detach().numpy() 176 | sns.set() 177 | # f, ax = plt.subplots(figsize=(9, 6)) 178 | sns.heatmap(score_input, annot=True) 179 | plt.title('SCORE INPUT') 180 | if os.path.exists(f'../heatmap_figure/{file_name}'): 181 | plt.savefig(f'../heatmap_figure/{file_name}/score input EPOCH:{EPOCH}.jpg') 182 | else: 183 | os.makedirs(f'../heatmap_figure/{file_name}') 184 | plt.savefig(f'../heatmap_figure/{file_name}/score input EPOCH:{EPOCH}.jpg') 185 | sns.heatmap(score_channel, annot=True) 186 | plt.title('SCORE CHANNEL') 187 | plt.savefig(f'../heatmap_figure/{file_name}/score channel EPOCH:{EPOCH}.jpg') 188 | 189 | 190 | if __name__ == '__main__': 191 | path = 
'E:\\PyCharmWorkSpace\\mtsdata\\JapaneseVowels\\JapaneseVowels.mat' # length=270 input=29 channel=12 output=9 192 | 193 | dataset = MyDataset(path, 'train') 194 | train_dataset = dataset.train_dataset 195 | print(train_dataset.shape) 196 | test_data = train_dataset[0].numpy() 197 | print(test_data[12, :]) 198 | # step_3 = train_dataset[0, 3, :].numpy() 199 | # step_12 = train_dataset[0, 12, :].numpy() 200 | # step_15 = train_dataset[0, 15, :].numpy() 201 | # step_25 = train_dataset[0, 25, :].numpy() 202 | # step_21 = train_dataset[0, 21, :].numpy() 203 | # step_27 = train_dataset[0, 27, :].numpy() 204 | # 205 | # # print(step_25.shape) 206 | # print(step_15) 207 | # print(step_25) 208 | # print(step_27) 209 | 210 | euclidean_norm = lambda x, y: np.abs(x - y) 211 | 212 | # d1, cost_matrix, acc_cost_matrix, path = dtw(step_3, step_12, dist=euclidean_norm) 213 | # d2, cost_matrix, acc_cost_matrix, path = dtw(step_15, step_25, dist=euclidean_norm) 214 | # d3, cost_matrix, acc_cost_matrix, path = dtw(step_3, step_21, dist=euclidean_norm) 215 | # d4, cost_matrix, acc_cost_matrix, path = dtw(step_15, step_27, dist=euclidean_norm) 216 | # 217 | # print(d1) 218 | # print(d2) 219 | # print(d3) 220 | # print(d4) 221 | 222 | print(test_data.shape) 223 | # matrix = np.ones((test_data.shape[0], test_data.shape[0])) # DTW 224 | # matrix_1 = np.ones((test_data.shape[0], test_data.shape[0])) # Euclidean distance 225 | # matrix_2 = np.ones((test_data.shape[0], test_data.shape[0])) # degree of difference 226 | # matrix_3 = np.ones((test_data.shape[0], test_data.shape[0])) # degree of difference 227 | matrix = np.ones((test_data.shape[1], test_data.shape[1])) # DTW 228 | matrix_1 = np.ones((test_data.shape[1], test_data.shape[1])) # Euclidean distance 229 | matrix_2 = np.ones((test_data.shape[1], test_data.shape[1])) # degree of difference 230 | matrix_3 = np.ones((test_data.shape[1], test_data.shape[1])) # dot product 231 | for i in range(test_data.shape[1]): 232 | for j in range(test_data.shape[1]): 233 | # x = test_data[i, :] 234 | # y = test_data[j, :] 235 | x = 
test_data[:, i] 236 | y = test_data[:, j] 237 | d, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=euclidean_norm) 238 | matrix[i, j] = d 239 | matrix_1[i, j] = np.sum(np.abs(x-y)) 240 | matrix_2[i, j] = np.mean((x - y) ** 2) 241 | matrix_3[i, j] = float(x.dot(y)) 242 | 243 | # print(matrix) 244 | 245 | import numpy as np 246 | import seaborn as sns 247 | import matplotlib.pyplot as plt 248 | sns.set() 249 | # np.random.seed(0) 250 | # uniform_data = np.random.rand(10, 12) 251 | # ax = sns.heatmap(matrix) 252 | # f, ax = plt.subplots(figsize=(9, 6)) 253 | 254 | sns.heatmap(matrix, annot=True) 255 | plt.title('input') 256 | # os.mkdir('../heatmap_figure/lala') 257 | if os.path.exists('../heatmap_figure/lala'): 258 | plt.savefig(f'../heatmap_figure/lala/1.jpg') 259 | else: 260 | os.makedirs('../heatmap_figure/lala') 261 | plt.savefig(f'../heatmap_figure/lala/1.jpg') 262 | plt.show() 263 | 264 | plt.title('channel') 265 | sns.heatmap(matrix_3, annot=True) 266 | plt.savefig('2.jpg') 267 | plt.show() 268 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/mytest/__pycache__/HeatMap_DTW.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/mytest/__pycache__/HeatMap_DTW.cpython-37.pyc -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/mytest/kmeans_test.py: -------------------------------------------------------------------------------- 1 | # import cPickle 2 | # X,y = cPickle.load(open('data.pkl','r')) #X和y都是numpy.ndarray类型 3 | # X.shape #输出(1000,2) 4 | # y.shape #输出(1000,)对应每个样本的真实标签 5 | from dataset_process.dataset_process import MyDataset 6 | 7 | # path = 'E:\\PyCharmWorkSpace\\mtsdata\\ECG\\ECG.mat' # lenth=100 input=152 channel=2 output=2 8 | # path = 
'E:\\PyCharmWorkSpace\\mtsdata\\JapaneseVowels\\JapaneseVowels.mat' # lenth=270 input=29 channel=12 output=9 9 | path = 'E:\\PyCharmWorkSpace\\mtsdata\\KickvsPunch\\KickvsPunch.mat' # lenth=10 input=841 channel=62 output=2 10 | 11 | 12 | 13 | dataset = MyDataset(path, 'train') 14 | X = dataset.train_dataset 15 | # X = torch.mean(X, dim=1).numpy() 16 | X = X.reshape(X.shape[0], -1).numpy() 17 | y = dataset.train_label.numpy() 18 | 19 | 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | from utils.kmeans import KMeans 23 | def draw(X, Y): 24 | clf = KMeans(n_clusters=2, initCent='random' ,max_iter=100) 25 | clf.fit(X) 26 | cents = clf.centroids#质心 27 | labels = clf.labels#样本点被分配到的簇的索引 28 | sse = clf.sse 29 | #画出聚类结果,每一类用一种颜色 30 | colors = ['b','g','r','k','c','m','y','#e24fff','#524C90','#845868'] 31 | n_clusters = 2 32 | for i in range(n_clusters): 33 | index = np.nonzero(labels==i)[0] 34 | x0 = X[index,0] 35 | x1 = X[index,1] 36 | y_i = Y[index] 37 | for j in range(len(x0)): 38 | plt.text(x0[j],x1[j],str(int(y_i[j])),color=colors[i],\ 39 | fontdict={'weight': 'bold', 'size': 9}) 40 | plt.scatter(cents[i,0],cents[i,1],marker='x',color=colors[i],linewidths=12) 41 | plt.title("SSE={:.2f}".format(sse)) 42 | plt.axis([-30,30,-30,30]) 43 | plt.show() -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/result_figure/CharacterTrajectories 96.6% Adam epoch=100 batch=4 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/CharacterTrajectories 96.6% Adam epoch=100 batch=4 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=10 batch=20 lr=0.0001 pe=True 
mask=True [512,6,6,8,8,0].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=10 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=20 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/ECG 86.0% Adam epoch=20 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/result_figure/ECG 87.0% Adam epoch=50 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/ECG 87.0% Adam epoch=50 batch=20 lr=0.0001 pe=True mask=True [512,6,6,8,8,0].png -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/result_figure/JapaneseVowels 98.11% Adagrad epoch=100 batch=3 lr=0.0001 pe=True mask=True [512,8,8,8,8,0.2].png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/result_figure/JapaneseVowels 98.11% Adagrad epoch=100 batch=3 lr=0.0001 pe=True mask=True [512,8,8,8,8,0.2].png -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/run.py: 
--------------------------------------------------------------------------------
# @Time    : 2021/01/22 25:16
# @Author  : SY.M
# @FileName: run.py

import torch
from torch.utils.data import DataLoader
from dataset_process.dataset_process import MyDataset
import torch.optim as optim
from time import time
from tqdm import tqdm
import os

from module.transformer import Transformer
from module.loss import Myloss
from utils.random_seed import setup_seed
from utils.visualization import result_visualization

# from mytest.gather.main import draw

setup_seed(30)  # set the random seed
reslut_figure_path = 'result_figure'  # directory where result figures are saved

# Dataset path selection
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\AUSLAN\\AUSLAN.mat'  # length=1140 input=136 channel=22 output=95
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\CharacterTrajectories\\CharacterTrajectories.mat'
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\CMUsubject16\\CMUsubject16.mat'  # length=29,29 input=580 channel=62 output=2
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\ECG\\ECG.mat'  # length=100 input=152 channel=2 output=2
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\JapaneseVowels\\JapaneseVowels.mat'  # length=270 input=29 channel=12 output=9
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\Libras\\Libras.mat'  # length=180 input=45 channel=2 output=15
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\UWave\\UWave.mat'  # length=4278 input=315 channel=3 output=8
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\KickvsPunch\\KickvsPunch.mat'  # length=10 input=841 channel=62 output=2
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\NetFlow\\NetFlow.mat'  # length=803 input=997 channel=4 output=only labels 1 and 13
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\ArabicDigits\\ArabicDigits.mat'  # length=6600 input=93 channel=13 output=10
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\PEMS\\PEMS.mat'
# path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\Wafer\\Wafer.mat'
path = 'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\WalkvsRun\\WalkvsRun.mat'

test_interval = 5  # evaluation interval, in epochs
draw_key = 1  # figures are saved only when the epoch count is >= draw_key
file_name = path.split('\\')[-1][0:path.split('\\')[-1].index('.')]  # dataset name (file name without extension)

# Hyperparameters
EPOCH = 100
BATCH_SIZE = 3
LR = 1e-4
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # select device: CPU or GPU
print(f'use device: {DEVICE}')

d_model = 512
d_hidden = 1024
q = 8
v = 8
h = 8
N = 8
dropout = 0.2
pe = True  # positional encoding for the score=input tower of the two-tower model; the score=channel tower has no pe by default
mask = True  # mask for the score=input tower of the two-tower model; the score=channel tower has no mask by default
# Optimizer selection
optimizer_name = 'Adagrad'

train_dataset = MyDataset(path, 'train')
test_dataset = MyDataset(path, 'test')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

DATA_LEN = train_dataset.train_len  # number of training samples
d_input = train_dataset.input_len  # number of time steps
d_channel = train_dataset.channel_len  # number of channels (time-series dimensionality)
d_output = train_dataset.output_len  # number of classes

# Show the dimensions
print('data structure: [lines, timesteps, features]')
print(f'train data size: [{DATA_LEN, d_input, d_channel}]')
print(f'test data size: [{train_dataset.test_len, d_input, d_channel}]')
print(f'Number of classes: {d_output}')

# Build the Transformer model
net = Transformer(d_model=d_model, d_input=d_input, d_channel=d_channel, d_output=d_output, d_hidden=d_hidden,
                  q=q, v=v, h=h, N=N, dropout=dropout, pe=pe, mask=mask, device=DEVICE).to(DEVICE)
# Loss function: cross-entropy loss
loss_function = Myloss()
if optimizer_name == 'Adagrad':
    optimizer = optim.Adagrad(net.parameters(), lr=LR)
elif optimizer_name == 'Adam':
    optimizer = optim.Adam(net.parameters(), lr=LR)

# Accuracy history
correct_on_train = []
correct_on_test = []
# Loss history
loss_list = []
time_cost = 0


# Evaluation function
def test(dataloader, flag='test_set'):
    correct = 0
    total = 0
    with torch.no_grad():
        net.eval()
        for x, y in dataloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            y_pre, _, _, _, _, _, _ = net(x, 'test')
            _, label_index = torch.max(y_pre.data, dim=-1)
            total += label_index.shape[0]
            correct += (label_index == y.long()).sum().item()
    if flag == 'test_set':
        correct_on_test.append(round((100 * correct / total), 2))
    elif flag == 'train_set':
        correct_on_train.append(round((100 * correct / total), 2))
    print(f'Accuracy on {flag}: %.2f %%' % (100 * correct / total))

    return round((100 * correct / total), 2)


# Training function
def train():
    max_accuracy = 0
    os.makedirs('saved_model', exist_ok=True)  # torch.save below fails if the directory is missing
    pbar = tqdm(total=EPOCH)
    begin = time()
    for index in range(EPOCH):
        # test() switches the model to eval mode, so training mode is restored every epoch
        net.train()
        for i, (x, y) in enumerate(train_dataloader):
            optimizer.zero_grad()

            y_pre, _, _, _, _, _, _ = net(x.to(DEVICE), 'train')

            loss = loss_function(y_pre, y.to(DEVICE))

            print(f'Epoch:{index + 1}:\t\tloss:{loss.item()}')
            loss_list.append(loss.item())

            loss.backward()

            optimizer.step()

        if ((index + 1) % test_interval) == 0:
            current_accuracy = test(test_dataloader)
            test(train_dataloader, 'train_set')
            print(f'Best accuracy so far\ttest set: {max(correct_on_test)}%\t training set: {max(correct_on_train)}%')

            if current_accuracy > max_accuracy:
                max_accuracy = current_accuracy
                torch.save(net, f'saved_model/{file_name} batch={BATCH_SIZE}.pkl')

        pbar.update()

    os.rename(f'saved_model/{file_name} batch={BATCH_SIZE}.pkl',
              f'saved_model/{file_name} {max_accuracy} batch={BATCH_SIZE}.pkl')

    end = time()
    time_cost = round((end - begin) / 60, 2)

    # Result figure
    result_visualization(loss_list=loss_list, correct_on_test=correct_on_test, correct_on_train=correct_on_train,
                         test_interval=test_interval,
                         d_model=d_model, q=q, v=v, h=h, N=N, dropout=dropout, DATA_LEN=DATA_LEN, BATCH_SIZE=BATCH_SIZE,
                         time_cost=time_cost, EPOCH=EPOCH, draw_key=draw_key, reslut_figure_path=reslut_figure_path,
                         file_name=file_name,
                         optimizer_name=optimizer_name, LR=LR, pe=pe, mask=mask)


if __name__ == '__main__':
    train()

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/run_with_saved_model.py:
--------------------------------------------------------------------------------
# @Time    : 2021/01/22 25:50
# @Author  : SY.M
# @FileName: run_with_saved_model.py


import torch
print('Current PyTorch version:', torch.__version__)
from utils.random_seed import setup_seed
from torch.utils.data import DataLoader
from dataset_process.dataset_process import MyDataset
from utils.heatMap import heatMap_all
from utils.TSNE import gather_by_tsne
from utils.TSNE import gather_all_by_tsne
import numpy as np
from utils.colorful_line import draw_colorful_line

setup_seed(30)
# The model is loaded onto the CPU below and its outputs are converted with .numpy(),
# so the inputs have to stay on the CPU as well.
DEVICE = torch.device("cpu")
# Dataset dimensions
# ArabicDigits  length=6600  input=93 channel=13 output=10
# AUSLAN  length=1140  input=136 channel=22 output=95
# CharacterTrajectories
# CMUsubject16  length=29,29  input=580 channel=62 output=2
# ECG  length=100  input=152 channel=2 output=2
# JapaneseVowels  length=270  input=29 channel=12 output=9
# Libras  length=180  input=45 channel=2 output=15
# UWave  length=4278  input=315 channel=3 output=8
# KickvsPunch  length=10  input=841 channel=62 output=2
# NetFlow  length=803  input=997 channel=4 output=only labels 1 and 13; the dataset-processing code needs to be modified
# Wafer  length=803  input=997 channel=4

# Choose the model to run
save_model_path = 'saved_model/ECG 91.0 batch=2.pkl'
file_name = save_model_path.split('/')[-1].split(' ')[0]
path = f'E:\\PyCharmWorkSpace\\dataset\\MTS_dataset\\{file_name}\\{file_name}.mat'  # assemble the dataset path

# Naming preparation for the HeatMap figures
ACCURACY = save_model_path.split('/')[-1].split(' ')[1]  # accuracy of the loaded model
BATCH_SIZE = int(save_model_path[save_model_path.find('=')+1:save_model_path.rfind('.')])  # batch size of the loaded model
heatMap_or_not = False  # whether to draw heatmaps of the score matrices
gather_or_not = False  # whether to draw the step-wise and channel-wise clustering plots of a single sample
gather_all_or_not = True  # whether to draw the clustering plot of all samples after feature extraction

# Load the model
net = torch.load(save_model_path, map_location=torch.device('cpu'))  # map_location sets the target device; probably needed because the original .pkl was trained on a GPU in Colab
net.eval()  # make sure dropout is disabled during inference
# Load the test set
test_dataset = MyDataset(path, 'test')
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f'Number of samples with the maximum number of steps: {len(test_dataset.max_length_sample_inTest)}')
if len(test_dataset.max_length_sample_inTest) == 0:
    gather_or_not = False
    heatMap_or_not = False
    print('No test sample has the maximum number of steps, so meaningful heatmap and gather figures cannot be drawn; try another dataset')

correct = 0
total = 0
with torch.no_grad():
    all_sample_X = []
    all_sample_Y = []
    for x, y in test_dataloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        y_pre, encoding, score_input, score_channel, gather_input, gather_channel, gate = net(x, 'test')

        all_sample_X.append(encoding)
        all_sample_Y.append(y)
        if heatMap_or_not:
            for index, sample in enumerate(test_dataset.max_length_sample_inTest):
                if sample.numpy().tolist() in x.numpy().tolist():
                    target_index = x.numpy().tolist().index(sample.numpy().tolist())
                    print('Drawing heatmap figure...')
                    heatMap_all(score_input[target_index], score_channel[target_index], sample, 'heatmap_figure_in_test', file_name, ACCURACY, index)
                    print('Heatmap figure done!')
        if gather_or_not:
            for index, sample in enumerate(test_dataset.max_length_sample_inTest):
                if sample.numpy().tolist() in x.numpy().tolist():
                    target_index = x.numpy().tolist().index(sample.numpy().tolist())
                    print('Drawing gather figure...')
                    gather_by_tsne(gather_input[target_index].numpy(), np.arange(gather_input[target_index].shape[0]), index, file_name+' input_gather')
                    gather_by_tsne(gather_channel[target_index].numpy(), np.arange(gather_channel[target_index].shape[0]), index, file_name+' channel_gather')
                    print('Gather figure done!')
                    draw_data = x[target_index].transpose(-1, -2)[0].numpy()
                    draw_colorful_line(draw_data)
                    gather_or_not = False

        _, label_index = torch.max(y_pre.data, dim=-1)
        total += label_index.shape[0]
        correct += (label_index == y.long()).sum().item()

    if gather_all_or_not:
        all_sample_X = torch.cat(all_sample_X, dim=0).numpy()
        all_sample_Y = torch.cat(all_sample_Y, dim=0).numpy()
        print('Drawing gather figure...')
        gather_all_by_tsne(all_sample_X, all_sample_Y, test_dataset.output_len, file_name+' all_sample_gather')
        print('Gather figure done!')

    print(f'Accuracy: %.2f %%' % (100 * correct / total))

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/TSNE.py:
--------------------------------------------------------------------------------
from dataset_process.dataset_process import MyDataset
import torch
import numpy as np
import os
import matplotlib.pyplot as plt

# path = 'E:\\PyCharmWorkSpace\\mtsdata\\ECG\\ECG.mat'  # length=100 input=152 channel=2 output=2
#
# dataset = MyDataset(path, 'train')
# X = dataset.train_dataset
# X = torch.mean(X, dim=1).numpy()
# # X = X.reshape(X.shape[0], -1).numpy()
# y = dataset.train_label.numpy()

from matplotlib import cm


def plot_with_labels(lowDWeights, labels, kinds, file_name):
    plt.cla()
    # The data has been reduced to 2-D; split it into x and y coordinates
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    # Iterate over every point and its label
    for x, y, s in zip(X, Y, labels):
        c = cm.rainbow(int(255/kinds * s))  # split the 0-255 colour range into `kinds` bins and map each label to one bin
        plt.text(x, y, s, backgroundcolor=c, fontsize=6)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())

    plt.title('Clustering Step-Wise after Embedding')
    plt.rcParams['figure.figsize'] = (10.0, 10.0)  # set the figure size

    os.makedirs(f'gather_figure/{file_name.split(" ")[0]}', exist_ok=True)

    plt.savefig(f'gather_figure/{file_name.split(" ")[0]}/{file_name}.jpg', dpi=600)
    # plt.show()
    plt.close()

def plot_only(lowDWeights, labels, index, file_name):
    """
    Draw the clustering plot and colour the labels.
    :param lowDWeights: data after dimensionality reduction, used to draw the clustering plot
    :param labels: labels corresponding to lowDWeights
    :param index: used in the file name to tell figures apart and avoid overwriting
    :param file_name: file name and clustering mode
    :return: None
    """
    plt.cla()
    # The data has been reduced to 2-D; split it into x and y coordinates
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    # Iterate over every point and its label
    # Custom colouring for the clustering plot should be done inside the for-loop below
    for x, y, s in zip(X, Y, labels):
        position = 255
        if x < -850:
            position = 255
        elif 0.5*x - 225 < y:
            position = 0
        elif x < 1500:
            position = 50
        else:
            position = 100

        # c = cm.rainbow(int(255/9 * s))  # split the 0-255 colour range into 9 bins and map each label to one bin
        c = cm.rainbow(position)  # colour chosen from the hand-tuned regions above
        plt.text(x, y, s, backgroundcolor=c, fontsize=6)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())
    # plt.title('Clustering Step-Wise after Embedding')

    plt.rcParams['figure.figsize'] = (10.0, 10.0)  # set the figure size

    os.makedirs(f'gather_figure/{file_name.split(" ")[0]}', exist_ok=True)

    plt.savefig(f'gather_figure/{file_name.split(" ")[0]}/{file_name} {index}.jpg', dpi=600)
    # plt.show()
    plt.close()


from sklearn.manifold import TSNE


def gather_by_tsne(X: np.ndarray,
                   Y: np.ndarray,
                   index: int,
                   file_name: str):
    # t-SNE reduction to 2 dimensions (note: recent scikit-learn releases rename n_iter to max_iter)
    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=4000)
    low_dim_embs = tsne.fit_transform(X[:, :])
    labels = Y[:]
    plot_only(low_dim_embs, labels, index, file_name)


def gather_all_by_tsne(X: np.ndarray,
                       Y: np.ndarray,
                       kinds: int,
                       file_name: str):
    """
    Cluster the 2-D data produced after the gate.
    :param X: data to cluster (2-D)
    :param Y: labels of the data
    :param kinds: number of classes
    :param file_name: used to name the output file
    :return: None
    """
    # t-SNE reduction to 2 dimensions (note: recent scikit-learn releases rename n_iter to max_iter)
    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=4000)
    low_dim_embs = tsne.fit_transform(X[:, :])
    labels = Y[:]
    plot_with_labels(low_dim_embs, labels, kinds, file_name)


if __name__ == '__main__':
    path = 'E:\\PyCharmWorkSpace\\mtsdata\\ECG\\ECG.mat'  # length=100 input=152 channel=2 output=2

    dataset = MyDataset(path, 'train')
    X = dataset.train_dataset
    X = torch.mean(X, dim=1).numpy()
    # X = X.reshape(X.shape[0], -1).numpy()
    Y = dataset.train_label.numpy()
    # The original call omitted the required index and file_name arguments;
    # the values below are placeholders chosen for this demo.
    gather_by_tsne(X, Y, 0, 'ECG train_gather')

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/TSNE.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/TSNE.cpython-37.pyc

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/colorful_line.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/colorful_line.cpython-37.pyc

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/draw_line.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/draw_line.cpython-37.pyc

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/heatMap.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/heatMap.cpython-37.pyc

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/random_seed.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/random_seed.cpython-37.pyc

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/__pycache__/visualization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZZUFaceBookDL/GTN/46aa05e8b609609edff193f3dba5bc0436faba2e/Gated Transformer 论文IJCAI版/utils/__pycache__/visualization.cpython-37.pyc

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/colorful_line.py:
--------------------------------------------------------------------------------
from matplotlib.collections import LineCollection
import numpy as np
import math
import matplotlib.pyplot as plt


def draw_colorful_line(draw_data: np.ndarray):
    # Colour list, one entry per point (same length as the number of points)
    colors = []
    length = len(draw_data)
    x = np.arange(length)
    # Indices of the points that belong to each colour
    red = [1, 2, 3]
    blue = [4, 5, 6]
    cyan = [7, 8, 9]
    # Fill the colour list
    for i in range(length):
        if i in red:
            colors.append('red')
        elif i in blue:
            colors.append('blue')
        elif i in cyan:
            colors.append('cyan')
        else:
            colors.append('purple')
    y = draw_data

    points = np.array([x, y]).T.reshape(-1, 1, 2)  # shape 152,1,2
    # Split into segments, because colours are applied to line segments rather than to points
    segments = np.concatenate([points[:-1], points[1:]], axis=1)  # shape 151,2,2
    lc = LineCollection(segments, color=colors)
    ax = plt.axes()
    ax.set_xlim(0, length)
    ax.set_ylim(min(y), max(y))
    ax.add_collection(lc)
    plt.show()
    plt.close()


if __name__ == '__main__':
    pi = math.pi

    x = np.linspace(0, 4 * pi, 100)
    y = [math.cos(xx) for xx in x]
    lwidths = abs(x)
    color = []
    for i in range(len(y)):
        if i < 50:
            color.append('#FF0000')
        else:
            color.append('#000000')
    print(color)
    # print(x)
    # print(y)
    print('--------------------------------------')
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    print(np.array([x, y]).shape)
    print(points.shape)
    print('--------------------------------------')
    print(points[:-1].shape)
    print(points[1:].shape)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    print(segments.shape)
    lc = LineCollection(segments, linewidths=lwidths, color=color)

    ax = plt.axes()
    ax.set_xlim(min(x), max(x))
    ax.set_ylim(min(y), max(y))
    ax.add_collection(lc)
    plt.show()
    plt.close()

    '''
    fig, a = plt.subplots()
    a.add_collection(lc)
    a.set_xlim(0, 4*pi)
    a.set_ylim(-1.1, 1.1)
    fig.show()
    '''

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/draw_line.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt


def draw_line(list_1, list_2, list_3):
    plt.style.use('seaborn')  # on Matplotlib >= 3.6 this style is named 'seaborn-v0_8'
    fig = plt.figure()  # create the base figure
    ax1 = fig.add_subplot(211)
    ax2 = fig.add_subplot(212)

    ax1.plot(list_3, color='red', label='Target_Weight_for_step')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('value')
    ax1.set_title('One Sample Weight for Step')

    ax2.plot(list_1, color='red', label='Weight_for_step')
    ax2.plot(list_2, color='blue', label='Weight_for_channel')
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('value')
    ax2.set_title('Mean Weight Allocation')

    plt.legend(loc='best')
    plt.show()

def draw_heatmap_anylasis(sample):
    channel_00 = sample[:, 0].numpy()
    channel_11 = sample[:, 1].numpy()
    sample = sample.transpose(-1, -2)
    channel_0 = sample[0, :].numpy().tolist()
    channel_1 = sample[1, :].numpy().tolist()
    channel_6 = sample[6, :].numpy().tolist()
    channel_4 = sample[4, :].numpy().tolist()
    channel_3 = sample[3, :].numpy().tolist()
    channel_7 = sample[7, :].numpy().tolist()
    channel_11 = sample[11, :].numpy().tolist()
    channel_9 = sample[9, :].numpy().tolist()

    ax1 = plt.subplot(511)
    ax2 = plt.subplot(512)
    ax3 = plt.subplot(513)
    ax4 = plt.subplot(514)
    ax5 = plt.subplot(515)
    ax1.plot(channel_11, color='red', label='channel_11')
    ax1.plot(channel_7, label='channel_7')
    ax2.plot(channel_11, color='red', label='channel_11')
    ax2.plot(channel_6, label='channel_6')
    ax3.plot(channel_11, color='red', label='channel_11')
    ax3.plot(channel_3, label='channel_3')
    ax4.plot(channel_3, color='red', label='channel_3')
ax4.plot(channel_4, label='channel_4') 50 | ax5.plot(channel_9, color='red', label='channel_6') 51 | ax5.plot(channel_0, label='channel_0') 52 | ax1.legend(loc='best') 53 | ax2.legend(loc='best') 54 | ax3.legend(loc='best') 55 | ax4.legend(loc='best') 56 | plt.suptitle('Feature-wise Series Compare') 57 | plt.show() -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/utils/heatMap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dtw import dtw 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import os 6 | import torch 7 | 8 | def heatMap_all(score_input: torch.Tensor, # 二维数据 9 | score_channel: torch.Tensor, # 二维数据 10 | x: torch.Tensor, # 二维数据 11 | save_root: str, 12 | file_name: str, 13 | accuracy: str, 14 | index: int) -> None: 15 | score_channel = score_channel.detach().numpy() 16 | score_input = score_input.detach().numpy() 17 | draw_data = x.detach().numpy() 18 | 19 | euclidean_norm = lambda x, y: np.abs(x - y) # 用于计算DTW使用的函数,此处是一个计算欧氏距离的函数 20 | 21 | matrix_00 = np.ones((draw_data.shape[1], draw_data.shape[1])) # 用于记录channel之间DTW值的矩阵 22 | # matrix_01 = np.ones((draw_data.shape[1], draw_data.shape[1])) # 用于记录channel之间相差度的矩阵 23 | 24 | # matrix_10 = np.ones((draw_data.shape[0], draw_data.shape[0])) # 用于记录input之间DTW值的矩阵 25 | matrix_11 = np.ones((draw_data.shape[0], draw_data.shape[0])) # 用于记录input之间相差度值的矩阵 26 | 27 | for i in range(draw_data.shape[0]): 28 | for j in range(draw_data.shape[0]): 29 | x = draw_data[i, :] 30 | y = draw_data[j, :] 31 | # d, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=euclidean_norm) 32 | # matrix_10[i, j] = d 33 | matrix_11[i, j] = np.sqrt(np.sum((x - y) ** 2)) 34 | 35 | draw_data = draw_data.transpose(-1, -2) 36 | for i in range(draw_data.shape[0]): 37 | for j in range(draw_data.shape[0]): 38 | x = draw_data[i, :] 39 | y = draw_data[j, :] 40 | d, cost_matrix, acc_cost_matrix, path = dtw(x, 
y, dist=euclidean_norm) 41 | matrix_00[i, j] = d 42 | # matrix_01[i, j] = np.mean((x - y) ** 2) 43 | 44 | plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置figure_size尺寸 45 | plt.subplot(221) 46 | sns.heatmap(score_channel, cmap="YlGnBu", vmin=0) 47 | # plt.title('channel-wise attention') 48 | 49 | plt.subplot(222) 50 | sns.heatmap(matrix_00, cmap="YlGnBu", vmin=0) 51 | # plt.title('channel-wise DTW') 52 | 53 | plt.subplot(223) 54 | sns.heatmap(score_input, cmap="YlGnBu", vmin=0) 55 | # plt.title('step-wise attention') 56 | 57 | plt.subplot(224) 58 | sns.heatmap(matrix_11, cmap="YlGnBu", vmin=0) 59 | # plt.title('step-wise L2 distance') 60 | 61 | # plt.suptitle(f'{file_name.lower()}') 62 | 63 | if os.path.exists(f'{save_root}/{file_name}') == False: 64 | os.makedirs(f'{save_root}/{file_name}') 65 | plt.savefig(f'{save_root}/{file_name}/{file_name} accuracy={accuracy} {index}.jpg', dpi=400) 66 | 67 | # plt.show() 68 | plt.close() 69 | 70 | 71 | if __name__ == '__main__': 72 | matrix = torch.Tensor(range(24)).reshape(2, 3, 4) 73 | print(matrix.shape) 74 | file_name = 'lall' 75 | epcoh = 1 76 | 77 | data_channel = matrix.detach() 78 | data_input = matrix.detach() 79 | 80 | plt.subplot(2, 2, 1) 81 | sns.heatmap(data_channel[0].data.cpu().numpy()) 82 | plt.title("1") 83 | 84 | plt.subplot(2, 2, 2) 85 | sns.heatmap(data_input[0].data.cpu().numpy()) 86 | plt.title("2") 87 | 88 | plt.subplot(2, 2, 3) 89 | sns.heatmap(data_input[0].data.cpu().numpy()) 90 | plt.title("3") 91 | 92 | plt.subplot(2, 2, 4) 93 | sns.heatmap(data_input[0].data.cpu().numpy()) 94 | plt.title("4") 95 | 96 | plt.suptitle("JapaneseVowels Attention Heat Map", fontsize='x-large', fontweight='bold') 97 | # plt.savefig('result_figure/JapaneseVowels Attention Heat Map.png') 98 | plt.show() 99 | -------------------------------------------------------------------------------- /Gated Transformer 论文IJCAI版/utils/kmeans.py: -------------------------------------------------------------------------------- 1 | # 
coding=utf-8 2 | 3 | ''' 4 | @author: wepon, http://2hwp.com 5 | Reference: 6 | Book: <> 7 | Software: sklearn.cluster.KMeans 8 | ''' 9 | import numpy as np 10 | 11 | 12 | class KMeans(object): 13 | """ 14 | - 参数 15 | n_clusters: 16 | 聚类个数,即k 17 | initCent: 18 | 质心初始化方式,可选"random"或指定一个具体的array,默认random,即随机初始化 19 | max_iter: 20 | 最大迭代次数 21 | """ 22 | 23 | def __init__(self, n_clusters=5, initCent='random', max_iter=300): 24 | if hasattr(initCent, '__array__'): 25 | n_clusters = initCent.shape[0] 26 | self.centroids = np.asarray(initCent, dtype=np.float) 27 | else: 28 | self.centroids = None 29 | 30 | self.n_clusters = n_clusters 31 | self.max_iter = max_iter 32 | self.initCent = initCent 33 | self.clusterAssment = None 34 | self.labels = None 35 | self.sse = None 36 | 37 | # 计算两点的欧式距离 38 | 39 | def _distEclud(self, vecA, vecB): 40 | return np.linalg.norm(vecA - vecB) 41 | 42 | # 随机选取k个质心,必须在数据集的边界内 43 | def _randCent(self, X, k): 44 | n = X.shape[1] # 特征维数 45 | centroids = np.empty((k, n)) # k*n的矩阵,用于存储质心 46 | for j in range(n): # 产生k个质心,一维一维地随机初始化 47 | minJ = min(X[:, j]) 48 | rangeJ = float(max(X[:, j]) - minJ) 49 | centroids[:, j] = (minJ + rangeJ * np.random.rand(k, 1)).flatten() 50 | return centroids 51 | 52 | def fit(self, X): 53 | # 类型检查 54 | if not isinstance(X, np.ndarray): 55 | try: 56 | X = np.asarray(X) 57 | except: 58 | raise TypeError("numpy.ndarray required for X") 59 | 60 | m = X.shape[0] # m代表样本数量 61 | self.clusterAssment = np.empty((m, 2)) # m*2的矩阵,第一列存储样本点所属的族的索引值, 62 | # 第二列存储该点与所属族的质心的平方误差 63 | if self.initCent == 'random': 64 | self.centroids = self._randCent(X, self.n_clusters) 65 | 66 | clusterChanged = True 67 | for _ in range(self.max_iter): 68 | clusterChanged = False 69 | for i in range(m): # 将每个样本点分配到离它最近的质心所属的族 70 | minDist = np.inf; 71 | minIndex = -1 72 | for j in range(self.n_clusters): 73 | distJI = self._distEclud(self.centroids[j, :], X[i, :]) 74 | if distJI < minDist: 75 | minDist = distJI; 76 | minIndex = j 77 | if 
self.clusterAssment[i, 0] != minIndex:
                    clusterChanged = True
                self.clusterAssment[i, :] = minIndex, minDist ** 2

            if not clusterChanged:  # converged: no sample changed cluster, so stop iterating
                break
            for i in range(self.n_clusters):  # update each centroid to the mean of the points in its cluster
                ptsInClust = X[np.nonzero(self.clusterAssment[:, 0] == i)[0]]  # all points that belong to cluster i
                self.centroids[i, :] = np.mean(ptsInClust, axis=0)

        self.labels = self.clusterAssment[:, 0]
        self.sse = sum(self.clusterAssment[:, 1])

    def predict(self, X):  # predict the cluster of new input data from the fitted centroids
        # type check
        if not isinstance(X, np.ndarray):
            try:
                X = np.asarray(X)
            except Exception:
                raise TypeError("numpy.ndarray required for X")

        m = X.shape[0]  # m is the number of samples
        preds = np.empty((m,))
        for i in range(m):  # assign each sample to the cluster of its nearest centroid
            minDist = np.inf
            for j in range(self.n_clusters):
                distJI = self._distEclud(self.centroids[j, :], X[i, :])
                if distJI < minDist:
                    minDist = distJI
                    preds[i] = j
        return preds


class biKMeans(object):
    def __init__(self, n_clusters=5):
        self.n_clusters = n_clusters
        self.centroids = None
        self.clusterAssment = None
        self.labels = None
        self.sse = None

    # Euclidean distance between two points
    def _distEclud(self, vecA, vecB):
        return np.linalg.norm(vecA - vecB)

    def fit(self, X):
        m = X.shape[0]
        self.clusterAssment = np.zeros((m, 2))
        centroid0 = np.mean(X, axis=0).tolist()
        centList = [centroid0]
        for j in range(m):  # initial squared error between each sample and the single centroid
            self.clusterAssment[j, 1] = self._distEclud(np.asarray(centroid0), X[j, :]) ** 2

        while (len(centList) < self.n_clusters):
            lowestSSE = np.inf
            for i in range(len(centList)):  # try splitting each cluster and keep the split with the lowest total error
                ptsInCurrCluster = X[np.nonzero(self.clusterAssment[:, 0] == i)[0], :]
                clf = KMeans(n_clusters=2)
                clf.fit(ptsInCurrCluster)
                centroidMat, splitClustAss = clf.centroids,
clf.clusterAssment  # centroids, assignments and error matrix obtained by splitting this cluster
                sseSplit = sum(splitClustAss[:, 1])
                sseNotSplit = sum(self.clusterAssment[np.nonzero(self.clusterAssment[:, 0] != i)[0], 1])
                if (sseSplit + sseNotSplit) < lowestSSE:
                    bestCentToSplit = i
                    bestNewCents = centroidMat
                    bestClustAss = splitClustAss.copy()
                    lowestSSE = sseSplit + sseNotSplit
            # after the split, one sub-cluster keeps the index of the original cluster and the other gets index len(centList) and is appended to centList
            bestClustAss[np.nonzero(bestClustAss[:, 0] == 1)[0], 0] = len(centList)
            bestClustAss[np.nonzero(bestClustAss[:, 0] == 0)[0], 0] = bestCentToSplit
            centList[bestCentToSplit] = bestNewCents[0, :].tolist()
            centList.append(bestNewCents[1, :].tolist())
            self.clusterAssment[np.nonzero(self.clusterAssment[:, 0] == bestCentToSplit)[0], :] = bestClustAss

        self.labels = self.clusterAssment[:, 0]
        self.sse = sum(self.clusterAssment[:, 1])
        self.centroids = np.asarray(centList)

    def predict(self, X):  # predict the cluster of new input data from the fitted centroids
        # type check
        if not isinstance(X, np.ndarray):
            try:
                X = np.asarray(X)
            except Exception:
                raise TypeError("numpy.ndarray required for X")

        m = X.shape[0]  # m is the number of samples
        preds = np.empty((m,))
        for i in range(m):  # assign each sample to the cluster of its nearest centroid
            minDist = np.inf
            for j in range(self.n_clusters):
                distJI = self._distEclud(self.centroids[j, :], X[i, :])
                if distJI < minDist:
                    minDist = distJI
                    preds[i] = j
        return preds

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/random_seed.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


def setup_seed(seed):
    # seed PyTorch (CPU and all GPUs) and NumPy, and ask cuDNN for deterministic kernels
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

--------------------------------------------------------------------------------
/Gated Transformer 论文IJCAI版/utils/visualization.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties as fp  # 1. import FontProperties
import math

def result_visualization(loss_list: list,
                         correct_on_test: list,
                         correct_on_train: list,
                         test_interval: int,
                         d_model: int,
                         q: int,
                         v: int,
                         h: int,
                         N: int,
                         dropout: float,
                         DATA_LEN: int,
                         BATCH_SIZE: int,
                         time_cost: float,
                         EPOCH: int,
                         draw_key: int,
                         reslut_figure_path: str,
                         optimizer_name: str,
                         file_name: str,
                         LR: float,
                         pe: bool,
                         mask: bool):
    my_font = fp(fname=r"font/simsun.ttc")  # 2. set the font path

    # set the plot style (Matplotlib 3.6 and later name this style 'seaborn-v0_8')
    plt.style.use('seaborn')

    fig = plt.figure()  # create the base figure
    ax1 = fig.add_subplot(311)  # create two subplots
    ax2 = fig.add_subplot(313)

    ax1.plot(loss_list)  # add the curves
    ax2.plot(correct_on_test, color='red', label='on Test Dataset')
    ax2.plot(correct_on_train, color='blue', label='on Train Dataset')

    # set the axis labels and the subplot titles
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('loss')
    ax2.set_xlabel(f'epoch/{test_interval}')
    ax2.set_ylabel('correct')
    ax1.set_title('LOSS')
    ax2.set_title('CORRECT')

    plt.legend(loc='best')

    # add the summary text; the Chinese labels report the minimum loss and its epoch, the final loss,
    # the best test/train accuracy and its epoch, the final accuracy, the hyperparameters and the time cost in minutes
    fig.text(x=0.13, y=0.4, s=f'最小loss:{min(loss_list)}' ' '
             f'最小loss对应的epoch数:{math.ceil((loss_list.index(min(loss_list)) + 1) / math.ceil((DATA_LEN / BATCH_SIZE)))}' ' '
             f'最后一轮loss:{loss_list[-1]}' '\n'
             f'最大correct:测试集:{max(correct_on_test)}% 训练集:{max(correct_on_train)}%' ' '
             f'最大correct对应的已训练epoch数:{(correct_on_test.index(max(correct_on_test)) + 1) * test_interval}' ' '
             f'最后一轮correct:{correct_on_test[-1]}%' '\n'
             f'd_model={d_model} q={q} v={v} h={h} N={N} drop_out={dropout}' '\n'
             f'共耗时{round(time_cost, 2)}分钟', fontproperties=my_font)

    # save the result figure; test runs (EPOCH below draw_key) are not saved
    if EPOCH >= draw_key:
        plt.savefig(
            f'{reslut_figure_path}/{file_name} {max(correct_on_test)}% {optimizer_name} epoch={EPOCH} batch={BATCH_SIZE} lr={LR} pe={pe} mask={mask} [{d_model},{q},{v},{h},{N},{dropout}].png')

    # show the figure
    plt.show()

    # print the same summary to the console (Chinese labels, same meaning as the figure text)
    print('正确率列表', correct_on_test)

    print(f'最小loss:{min(loss_list)}\r\n'
          f'最小loss对应的epoch数:{math.ceil((loss_list.index(min(loss_list)) + 1) / math.ceil((DATA_LEN / BATCH_SIZE)))}\r\n'
          f'最后一轮loss:{loss_list[-1]}\r\n')

    print(f'最大correct:测试集:{max(correct_on_test)}\t 训练集:{max(correct_on_train)}\r\n'
          f'最大correct对应的已训练epoch数:{(correct_on_test.index(max(correct_on_test)) + 1) * test_interval}\r\n'
          f'最后一轮correct:{correct_on_test[-1]}')

    print(f'共耗时{round(time_cost, 2)}分钟')

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Gated-Transformer-on-MTS
A PyTorch-based project that applies an improved Transformer model to multivariate time series classification.

## Experimental results
The comparison baselines are Fully Convolutional Networks (FCN) and Residual Networks (ResNet).


DataSet|MLP|FCN|ResNet|Encoder|MCNN|t-LeNet|MCDCNN|Time-CNN|TWIESN|Gated Transformer|
-------|---|---|------|-------|----|-------|------|--------|------|-----------------|
ArabicDigits|96.9|99.4|**99.6**|98.1|10.0|10.0|95.9|95.8|85.3|98.8|
AUSLAN|93.3|**97.5**|97.4|93.8|1.1|1.1|85.4|72.6|72.4|**97.5**|
CharacterTrajectories|96.9|**99.0**|**99.0**|97.1|5.4|6.7|93.8|96.0|92.0|97.0|
CMUsubject16|60.0|**100**|99.7|98.3|53.1|51.0|51.4|97.6|89.3|**100**|
ECG|74.8|87.2|86.7|87.2|67.0|67.0|50.0|84.1|73.7|**91.0**|
JapaneseVowels|97.6|**99.3**|99.2|97.6|9.2|23.8|94.4|95.6|96.5|98.7|
Libras|78.0|**96.4**|95.4|78.3|6.7|6.7|65.1|63.7|79.4|88.9|
UWave|90.1|**93.4**|92.6|90.8|12.5|12.5|84.5|8.9|75.4|91.0|
KickvsPunch|61.0|54.0|51.0|61.0|54.0|50.0|56.0|62.0|67.0|**90.0**|
NetFlow|55.0|89.1|62.7|77.7|77.9|72.3|63.0|89.0|94.5|**100**|
PEMS|-|-|-|-|-|-|-|-|-|93.6|
Wafer|89.4|98.2|98.9|98.6|89.4|89.4|65.8|94.8|94.9|**99.1**|
WalkvsRun|70.0|**100**|**100**|**100**|75.0|60.0|45.0|**100.0**|94.4|**100**|

## Environment
Item|Description|
---|---------|
Language|Python3.7|
Framework|Pytorch1.6|
IDE|Pycharm and Colab|
Device|CPU and GPU|

## Datasets
The datasets are multivariate time series stored as .mat files. Each file holds both the training set and the test set, predefined as test data, test labels, training data, and training labels.
The datasets are hosted on Baidu Netdisk; the link is as follows:
Link: https://pan.baidu.com/s/1u2HN6tfygcQvzuEK5XBa2A
Extraction code: dxq6
Google drive link: https://drive.google.com/drive/folders/1QFadJOmbOLWMjLrcebZQR_w2fBX7x0Vm?usp=share_link

UEA and UCR dataset: http://www.timeseriesclassification.com/index.php

---

Dataset dimensions

DataSet|Number of Classes|Size of training Set|Size of testing Set|Max Time series Length|Channel|
-------|-----------------|--------------------|-------------------|----------------------|-------|
ArabicDigits|10|6600|2200|93|13|
AUSLAN|95|1140|1425|136|22|
CharacterTrajectories|20|300|2558|205|3|
CMUsubject16|2|29|29|580|62|
ECG|2|100|100|152|2|
JapaneseVowels|9|270|370|29|12|
Libras|15|180|180|45|2|
UWave|8|200|4278|315|3|
KickvsPunch|2|16|10|841|62|
NetFlow|2|803|534|997|4|
PEMS|7|267|173|144|963|
Wafer|2|298|896|198|6|
WalkvsRun|2|28|16|1918|62|

## Data preprocessing
See dataset_process.py for the full dataset processing procedure.
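- Create a torch.utils.data.Dataset object and process the dataset inside the class. Its member variables include the training data, training labels, test data, test labels, and so on. Create a torch.utils.data.DataLoader object to produce the mini-batches and the random shuffle of the dataset during training.
- Samples in a dataset have different time series lengths. Preprocessing takes the **longest** number of time steps across all samples (in both the test set and the training set) as the Time series Length and pads shorter samples with **0**.
- Preprocessing also keeps the unpadded training and test data, plus a list of the test samples with the most time steps, for use when exploring the model.
- In the NetFlow dataset the labels are **1 and 13**, so the returned label values must be remapped when this dataset is used.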
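
The padding and label handling described in the list above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the code in dataset_process.py: the helper names `pad_to_max_length` and `remap_labels` are hypothetical, and each sample is assumed to be an array of shape (time_steps, channels).

```python
import numpy as np


def pad_to_max_length(samples, max_len):
    """Zero-pad variable-length samples to a common length.

    samples: list of arrays shaped (time_steps, channels); hypothetical input layout.
    max_len: target length, e.g. the longest sample across train and test.
    """
    channels = samples[0].shape[1]
    padded = np.zeros((len(samples), max_len, channels), dtype=np.float32)
    for i, sample in enumerate(samples):
        # copy the real time steps; the tail of each sample stays 0
        padded[i, :sample.shape[0], :] = sample
    return padded


def remap_labels(labels):
    """Map arbitrary label values (such as 1 and 13) to 0..n_classes-1."""
    classes = np.unique(labels)                 # sorted unique label values
    indices = np.searchsorted(classes, labels)  # position of each label in classes
    return indices, classes


# toy data: two training samples and one test sample with 2 channels each
train = [np.ones((3, 2)), np.ones((5, 2))]
test = [np.ones((4, 2))]

# the target length is taken over the training and test sets together
max_len = max(s.shape[0] for s in train + test)
train_padded = pad_to_max_length(train, max_len)
test_padded = pad_to_max_length(test, max_len)

label_idx, classes = remap_labels(np.array([1, 13, 13, 1]))
```

Computing `max_len` over both sets keeps the input length identical at training and test time, which matters because d_input is fixed by preprocessing.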
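
## Model description



- Encoder only: since this is a classification task, the model drops the decoder of the original Transformer and uses **only the Encoder** for classification.
- Two Tower: there are clearly many relationships between different steps and between different channels. The original Transformer uses the attention mechanism to measure how strongly different steps or different channels are related, but it computes only one of the two. A CNN handles time series differently: its 2D convolution kernels attend to step-wise and channel-wise structure at the same time. Here we use a **two-tower** model that computes step-wise attention and channel-wise attention simultaneously.
- Gate mechanism: which attention works better varies from dataset to dataset. The simple way to combine the features extracted by the two towers is to concatenate their outputs. Here, instead, the model learns two weights and assigns one to each tower's output, as follows.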
`h = W · Concat(C, S) + b`
`g1, g2 = Softmax(h)`
`y = Concat(C · g1, S · g2)`
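
The three gating formulas above can be sketched as a forward pass. This is a minimal NumPy illustration rather than the repository's PyTorch module. It assumes that `C` and `S` are the flattened outputs of the channel-wise and step-wise towers for a batch of samples, and that `W` and `b` are the learned parameters of a linear layer with two outputs. The function name `gate` is hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # subtract the row maximum for numerical stability
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def gate(C, S, W, b):
    """Weight the two tower outputs with a learned two-way gate.

    C: (batch, dim_c) channel-wise tower output (assumed flattened)
    S: (batch, dim_s) step-wise tower output (assumed flattened)
    W: (dim_c + dim_s, 2) and b: (2,) are the linear-layer parameters
    """
    h = np.concatenate([C, S], axis=-1) @ W + b    # h = W · Concat(C, S) + b
    g = softmax(h, axis=-1)                        # g1, g2 = Softmax(h)
    g1, g2 = g[:, 0:1], g[:, 1:2]                  # keep a trailing axis for broadcasting
    y = np.concatenate([C * g1, S * g2], axis=-1)  # y = Concat(C · g1, S · g2)
    return y, g


# toy check: with zero weights the gate splits evenly between the towers
C = np.arange(8, dtype=float).reshape(2, 4)
S = np.ones((2, 3))
W = np.zeros((7, 2))
b = np.zeros(2)
y, g = gate(C, S, W, b)
```

Because the two gate values come from a softmax they sum to 1 for every sample, so the gate redistributes weight between the towers instead of scaling both freely.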
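- On the step-wise side, the model adds positional encoding and the mask mechanism, as in the original Transformer. On the channel-wise side it drops both, because neither has any real meaning between channels that have no temporal ordering.

## Hyperparameters
Hyperparameter|Description|
----|---|
d_model|The model processes time series rather than natural language, so the word embedding used in NLP is omitted and a single linear layer maps the input to a dense d_model-dimensional vector. d_model also keeps the dimensions consistent wherever modules connect|
d_hidden|Dimension of the hidden layer in the Position-wise FeedForward network|
d_input|Time series length, i.e. the largest number of time steps in a dataset. **Fixed**: determined directly by dataset preprocessing|
d_channel|Number of channels of the multivariate time series, i.e. its dimensionality. **Fixed**: determined directly by dataset preprocessing|
d_output|Number of classes. **Fixed**: determined directly by dataset preprocessing|
q,v|Output dimensions of the linear projections in Multi-Head Attention|
h|Number of heads in Multi-Head Attention|
N|Number of Encoders in the Encoder stack|
dropout|Dropout rate|
EPOCH|Number of training epochs|
BATCH_SIZE|mini-batch size|
LR|Learning rate, set to 1e-4|
optimizer_name|Optimizer choice; **Adagrad** and Adam are recommended|

## Files
File name|Description|
-------|----|
dataset_process|Dataset processing|
font|Fonts used for the text in the result figures|
gather_figure|Clustering result figures|
heatmap_figure_in_test|Heat maps of the score matrices drawn while testing the model|
module|The modules of the model|
mytest|Assorted test code|
reslut_figure|Accuracy result figures|
saved_model|Saved pkl files|
utils|Utility files|
run.py|Trains the model|
run_with_saved_model.py|Tests a trained model (saved as a pkl file)|

## utils
A brief introduction to a few of the utilities:
- random_seed: sets the **random seed** so that the results of each experiment are reproducible.
- heatMap.py: draws **heat maps** of the two towers' score matrices to analyze how strongly channels relate to each other and how strongly steps relate to each other. **DTW** matrices and Euclidean distance matrices are drawn for comparison, to analyze what determines the weight allocation.
- draw_line: draws line charts; a new function usually has to be written for each specific need.
- visualization: draws the loss curve and the accuracy curve of the training run, to judge convergence and overfitting.
- TSNE: runs the **t-SNE dimensionality reduction** algorithm and draws the resulting cluster plot, to evaluate the quality of the model's feature extraction or the similarity between time series.

## Tips
- A .pkl file only exists after training: the model is saved at the end of training when the corresponding parameter is set to True. Because of GitHub's file size limits, the uploaded files do not include trained .pkl files.
- The .pkl files were saved with PyTorch 1.6 in PyCharm and PyTorch 1.7 on Colab. To load a model directly for testing, use PyTorch **1.6 or later** where possible.
- Root-level directories such as saved_model and reslut_figure are the default save paths. Do not delete or rename them unless you change the paths directly in the source code.
- Please use the datasets provided on Baidu Netdisk. Different MTS datasets come in different file formats, and this project processes .mat files.
- For the utilities in utils that draw colored curves and cluster plots, the color assignment cannot be generalized across use cases, so please write that code yourself inside the functions.
- The .pkl file saved by save model is updated continuously during training, and at the end the model with the highest accuracy is saved and named. Do not change the naming format: run_with_saved_model.py uses the information in the file name, and the names of several plot outputs draw on it as well.
- 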
The GPU is used when available; otherwise the CPU is used.

## References
```
[Wang et al., 2017] Z. Wang, W. Yan, and T. Oates. Time series classification from scratch with deep neural networks: A strong baseline. In 2017 International Joint Conference on Neural Networks (IJCNN), pages 1578–1585, 2017.
```

## My knowledge is limited; criticism and corrections of any mistakes in the code or the text are welcome!
## Contact: masiyuan007@qq.com
--------------------------------------------------------------------------------