├── README.md
├── GCN_models.py
├── layers.py
├── One_hot_encoder.py
├── validation.py
├── train.py
├── PEMSD7
│   └── W_25.csv
└── ST_Transformer.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ST-Transformer
paper: "Spatial-Temporal Transformer Networks for Traffic Flow Forecasting"
--------------------------------------------------------------------------------
/GCN_models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphConvolution


class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        # log_softmax is inherited from the node-classification GCN this code
        # is based on; here its output is consumed as node features.
        return F.log_softmax(x, dim=1)
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module


class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x, adj):
        support = torch.mm(x, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
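
# The layer above multiplies by adj exactly as it is passed in; the symmetric
# renormalization from Kipf & Welling is left to the caller (this repo instead
# feeds an InstanceNorm-ed distance matrix, see STransformer). A minimal
# sketch of that standard preprocessing, assuming a dense adjacency A; it is
# not used anywhere in this repo:
def normalize_adj(A):
    # A_hat = D^{-1/2} (A + I) D^{-1/2}
    A = A + torch.eye(A.size(0))
    d = A.sum(dim=1)
    D_inv_sqrt = torch.diag(d.pow(-0.5))
    return D_inv_sqrt @ A @ D_inv_sqrt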
--------------------------------------------------------------------------------
/One_hot_encoder.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 10 16:13:06 2020

@author: wb
"""
import torch
import torch.nn as nn

class One_hot_encoder(nn.Module):
    def __init__(self, embed_size, time_num=288):
        super(One_hot_encoder, self).__init__()
        self.time_num = time_num
        # Identity matrix: row t is the one-hot encoding of time slot t.
        self.I = nn.Parameter(torch.eye(self.time_num, self.time_num, requires_grad=True))
        self.onehot_Linear = nn.Linear(time_num, embed_size)

    def forward(self, i, N=25, T=12):
        # Take T consecutive one-hot rows starting at slot i, wrapping past
        # the end of the day when necessary.
        if i % self.time_num + T > self.time_num:
            o1 = self.I[i % self.time_num:, :]
            o2 = self.I[0:(i + T) % self.time_num, :]
            onehot = torch.cat((o1, o2), 0)
        else:
            onehot = self.I[i % self.time_num:i % self.time_num + T, :]

        # onehot = onehot.repeat(N, 1, 1)
        onehot = onehot.expand(N, T, self.time_num)
        onehot = self.onehot_Linear(onehot)
        return onehot
'''
def one_hot_function(i, time_num=288, N=25, T=12):

    I = torch.eye(time_num, time_num)

    if i % time_num + T > time_num:
        o1 = I[i % time_num:, :]
        o2 = I[0:(i + T) % time_num, :]
        onehot = torch.cat((o1, o2), 0)
    else:
        onehot = I[i % time_num:i % time_num + T, :]

    # onehot = onehot.repeat(N, 1, 1)
    onehot = onehot.expand(N, T, time_num)

    return onehot'''
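
# A quick check of the wrap-around behaviour above (a standalone sketch,
# assuming the defaults: 288 five-minute slots per day). A window starting at
# slot 280 uses rows 280..287 followed by rows 0..3:
def _wraparound_demo():
    enc = One_hot_encoder(embed_size=64, time_num=288)
    out = enc(i=280, N=25, T=12)
    print(out.shape)  # torch.Size([25, 12, 64])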
--------------------------------------------------------------------------------
/validation.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Fri Oct  9 22:20:00 2020

@author: wb
"""
import torch
import torch.nn as nn
from ST_Transformer import STTransformer
import pandas as pd
import numpy as np

def MAE(x, y):  # hand-rolled MAE, kept alongside nn.L1Loss for comparison
    out = torch.abs(x - y)
    return out.mean(dim=0)

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    days = 10
    val_days = 2  # number of days to validate on
    train_num = 288 * days
    val_num = 288 * val_days
    row_num = train_num + val_num

    # header=None: the CSV files have no header row (pandas no longer
    # accepts header=-1)
    v = pd.read_csv("PEMSD7/V_25.csv", nrows=row_num, header=None)
    A = pd.read_csv("PEMSD7/W_25.csv", header=None)

    A = np.array(A)
    A = torch.tensor(A, dtype=torch.float32)
    v = np.array(v)
    v = v.T
    v = torch.tensor(v, dtype=torch.float32)

    in_channels = 1
    embed_size = 64
    time_num = 288
    num_layers = 1
    T_dim = 12
    output_T_dim = 3
    heads = 2

    # model = STTransformer(in_channels, embed_size, time_num, num_layers, T_dim, output_T_dim, heads)
    model = torch.load('model.pth')
    model.eval()  # disable dropout during validation
    criterion1 = nn.L1Loss()   # MAE
    criterion3 = nn.MSELoss()  # RMSE = sqrt(MSE)

    with torch.no_grad():
        for i in range(train_num, row_num - 15):
            x = v[:, i:i + 12]
            x = x.unsqueeze(0)
            y = v[:, i + 12:i + 15]

            out = model(x, A, i)

            # out = out.T
            # y = y.T

            loss1 = criterion1(out, y)
            loss2 = MAE(out, y)
            loss3 = torch.sqrt(criterion3(out, y))
            if i % 100 == 0:
                # print("out", out)
                print("MAE loss", loss1)
                print("per-horizon MAE", loss2)
                print("RMSE loss", loss3)

    # print(out)
    # print("output shape", out.shape)
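
# A sketch (not part of the original script) of aggregating MAE/RMSE over the
# whole validation span instead of printing every 100th window; it uses the
# same sliding windows as the loop above:
def aggregate_metrics(model, v, A, start, end):
    errs, sq_errs = [], []
    with torch.no_grad():
        for i in range(start, end):
            x = v[:, i:i + 12].unsqueeze(0)
            y = v[:, i + 12:i + 15]
            out = model(x, A, i)
            errs.append(torch.abs(out - y).mean())
            sq_errs.append(((out - y) ** 2).mean())
    mae = torch.stack(errs).mean()
    rmse = torch.stack(sq_errs).mean().sqrt()  # windows are equal-sized
    return mae, rmse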
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Oct  7 18:25:49 2020

@author: wb
"""
import torch
import torch.nn as nn
from ST_Transformer import STTransformer
import pandas as pd
import numpy as np


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    '''
    x1 = torch.tensor(  # x shape [C, N, T]
        [
            [
                [6.5, 5, 6, 4, 3, 9, 5, 2, 0],
                [4, 8, 7, 3, 4, 5, 6.5, 7, 2],
                [5, 6, 8, 9.1, 21, 4, 4, 6, 20],
                [2, 6, 8, 1, 3, 0, 2.2, 2, 5]
            ]
        ]
    ).to(device)

    Aa = torch.tensor([
        [1, 0, 1, 0],
        [0, 1, 0, 1],
        [2, 0, 1, 0],
        [1, 2, 0, 1.]
    ]).to(device)  # adjacency matrix
    '''

    days = 10
    val_days = 1

    train_num = 288 * days
    val_num = 288 * val_days
    row_num = train_num + val_num

    # header=None: the CSV files have no header row
    v = pd.read_csv("PEMSD7/V_25.csv", nrows=row_num, header=None)
    A = pd.read_csv("PEMSD7/W_25.csv", header=None)

    A = np.array(A)
    A = torch.tensor(A, dtype=torch.float32)

    v = np.array(v)
    v = v.T
    v = torch.tensor(v, dtype=torch.float32)

    '''
    x = v[:, 0:12]
    y = v[:, 12:]

    x = x.unsqueeze(0)
    in_channels = v.shape[0]
    '''

    in_channels = 1
    embed_size = 64
    time_num = 288  # number of 5-minute time slots per day
    num_layers = 1
    T_dim = 12
    output_T_dim = 3
    heads = 1

    model = STTransformer(in_channels, embed_size, time_num, num_layers, T_dim, output_T_dim, heads)

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.000001)  # very small lr
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)
    criterion = nn.L1Loss()  # MAE, the metric reported in the paper

    for i in range(train_num - 15):
        # sliding window: 12 steps of input, the next 3 steps as target
        x = v[:, i:i + 12]
        x = x.unsqueeze(0)
        y = v[:, i + 12:i + 15]

        out = model(x, A, i)
        loss = criterion(out, y)

        if i % 100 == 0:
            # print("out", out)
            print("MAE loss:", loss.item())

        # standard optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # print("output shape", out.shape)
    torch.save(model, "model.pth")
--------------------------------------------------------------------------------
/PEMSD7/W_25.csv:
--------------------------------------------------------------------------------
0,3170,8730,11900,7760,19900,18400,2210,5890,16100,13600,4270,9100,25400,13500,5370,16000,15800,18100,17600,15500,15200,19400,15200,18700
3170,0,5630,8750,4700,16700,15300,1040,2910,13000,10500,1330,5980,22300,10400,2390,13500,13200,15000,15300,12300,12300,16300,12300,15600
8730,5630,0,3280,1040,11500,10000,6660,2910,7630,5070,4470,387,16900,4970,3390,8920,8340,10200,11000,6740,7080,11800,6880,10900
11900,8750,3280,0,4320,8190,6760,9760,6170,4360,1790,7690,2890,13600,1690,6640,8210,7390,7130,10300,3640,5590,8770,4980,7880
7760,4700,1040,4320,0,12500,11100,5730,1890,8670,6100,3490,1420,17900,6000,2390,9350,8850,11200,11300,7740,7770,12700,7660,11900
19900,16700,11500,8190,12500,0,1450,17700,14300,3840,6400,15800,11100,5560,6500,14800,11900,11000,2950,13400,5410,9230,3470,8290,3110
18400,15300,10000,6760,11100,1450,0,16200,12900,2410,4970,14300,9640,6990,5060,13300,11000,10100,2370,12700,4170,8240,3550,7270,2860
2210,1040,6660,9760,5730,17700,16200,0,3940,13900,11400,2330,7010,23200,11400,3420,14500,14100,15800,16200,13400,13400,17200,13300,16500
5890,2910,2910,6170,1890,14300,12900,3940,0,10500,7940,1630,3290,19800,7850,523,10700,10300,12900,12600,9630,9420,14400,9410,13600
16100,13000,7630,4360,8670,3840,2410,13900,10500,0,2570,12000,7240,9310,2660,11000,9440,8490,3480,11300,2190,6510,5120,5540,4240
13600,10500,5070,1790,6100,6400,4970,11400,7940,2570,0,9430,4680,11900,96.5,8410,8530,7630,5480,10600,2330,5660,7140,4820,6240
4270,1330,4470,7690,3490,15800,14300,2330,1630,12000,9430,0,4840,21300,9340,1110,12200,11800,14200,14000,11200,11000,15600,11000,14800
9100,5980,387,2890,1420,11100,9640,7010,3290,7240,4680,4840,0,16500,4580,3770,8790,8170,9820,10800,6360,6850,11400,6610,10500
25400,22300,16900,13600,17900,5560,6990,23200,19800,9310,11900,21300,16500,0,12000,20300,15500,14700,7980,16500,10400,13400,7410,12600,7690
13500,10400,4970,1690,6000,6500,5060,11400,7850,2660,96.5,9340,4580,12000,0,8310,8520,7610,5560,10600,2390,5650,7220,4820,6320
5370,2390,3390,6640,2390,14800,13300,3420,523,11000,8410,1110,3770,20300,8310,0,11200,10800,13300,13000,10100,9940,14800,9930,13900
16000,13500,8920,8210,9350,11900,11000,14500,10700,9440,8530,12200,8790,15500,8520,11200,0,950,12900,2110,7260,2940,14400,3900,13600
15800,13200,8340,7390,8850,11000,10100,14100,10300,8490,7630,11800,8170,14700,7610,10800,950,0,11900,2960,6310,2000,13500,2950,12600
18100,15000,10200,7130,11200,2950,2370,15800,12900,3480,5480,14200,9820,7980,5560,13300,12900,11900,0,14700,5660,9970,1670,8990,773
17600,15300,11000,10300,11300,13400,12700,16200,12600,11300,10600,14000,10800,16500,10600,13000,2110,2960,14700,0,9140,4930,16200,5830,15400
15500,12300,6740,3640,7740,5410,4170,13400,9630,2190,2330,11200,6360,10400,2390,10100,7260,6310,5660,9140,0,4330,7280,3360,6410
15200,12300,7080,5590,7770,9230,8240,13400,9420,6510,5660,11000,6850,13400,5650,9940,2940,2000,9970,4930,4330,0,11600,979,10700
19400,16300,11800,8770,12700,3470,3550,17200,14400,5120,7140,15600,11400,7410,7220,14800,14400,13500,1670,16200,7280,11600,0,10600,902
15200,12300,6880,4980,7660,8290,7270,13300,9410,5540,4820,11000,6610,12600,4820,9930,3900,2950,8990,5830,3360,979,10600,0,9730
18700,15600,10900,7880,11900,3110,2860,16500,13600,4240,6240,14800,10500,7690,6320,13900,13600,12600,773,15400,6410,10700,902,9730,0
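
W_25.csv above stores pairwise road distances (row i, column j is the distance
between sensors i and j), not affinities; the model normalizes it with
InstanceNorm2d inside STransformer. A common alternative preprocessing for
PEMSD7-style distance matrices is a thresholded Gaussian kernel. A sketch, not
used anywhere in this repo; sigma and eps are free parameters:

import numpy as np

def gaussian_kernel_adj(W, sigma=None, eps=0.1):
    # w_ij = exp(-d_ij^2 / sigma^2), zeroed below eps; diagonal stays 0
    W = np.asarray(W, dtype=float)
    if sigma is None:
        sigma = W[W > 0].std()
    A = np.exp(-np.square(W / sigma))
    A[A < eps] = 0.0
    np.fill_diagonal(A, 0.0)
    return A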
--------------------------------------------------------------------------------
/ST_Transformer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 28 10:28:06 2020

@author: wb
"""

import torch
import torch.nn as nn
from GCN_models import GCN
from One_hot_encoder import One_hot_encoder

class SSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        # N nodes, T time steps, C = embed_size
        N, T, C = query.shape

        # Split the embedding into self.heads pieces of size head_dim
        values = values.reshape(N, T, self.heads, self.head_dim)
        keys = keys.reshape(N, T, self.heads, self.head_dim)
        query = query.reshape(N, T, self.heads, self.head_dim)

        values = self.values(values)    # (N, T, heads, head_dim)
        keys = self.keys(keys)          # (N, T, heads, head_dim)
        queries = self.queries(query)   # (N, T, heads, head_dim)

        # Spatial self-attention: node-to-node scores, computed independently
        # for every time step and head
        energy = torch.einsum("qthd,kthd->qkth", [queries, keys])
        # queries shape: (N, T, heads, head_dim)
        # keys shape: (N, T, heads, head_dim)
        # energy: (N, N, T, heads)

        # Softmax over the key (second) dimension so each query's weights sum
        # to 1; divide by a scaling factor for numerical stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=1)
        # attention shape: (N, N, T, heads)

        out = torch.einsum("qkth,kthd->qthd", [attention, values]).reshape(
            N, T, self.heads * self.head_dim
        )
        # attention shape: (N, N, T, heads)
        # values shape: (N, T, heads, head_dim)
        # out after matrix multiply: (N, T, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # The linear layer keeps the shape; final shape is (N, T, embed_size)

        return out
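
# Shape check for the spatial attention above (a standalone sketch, not called
# by the model): with N nodes the energy tensor is (N, N, T, heads), i.e.
# node-to-node scores computed independently per time step and head.
def _sself_attention_shape_demo(N=25, T=12, embed_size=64, heads=2):
    att = SSelfAttention(embed_size, heads)
    x = torch.randn(N, T, embed_size)
    out = att(x, x, x)
    assert out.shape == (N, T, embed_size)
    return out.shape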
class TSelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(TSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        N, T, C = query.shape

        # Split the embedding into self.heads pieces of size head_dim
        values = values.reshape(N, T, self.heads, self.head_dim)
        keys = keys.reshape(N, T, self.heads, self.head_dim)
        query = query.reshape(N, T, self.heads, self.head_dim)

        values = self.values(values)    # (N, T, heads, head_dim)
        keys = self.keys(keys)          # (N, T, heads, head_dim)
        queries = self.queries(query)   # (N, T, heads, head_dim)

        # Temporal self-attention: step-to-step scores, computed independently
        # for every node and head
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, T, heads, head_dim)
        # keys shape: (N, T, heads, head_dim)
        # energy: (N, heads, T, T)

        # Softmax over the key (last) dimension so each query's weights sum
        # to 1; divide by a scaling factor for numerical stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhqk,nkhd->nqhd", [attention, values]).reshape(
            N, T, self.heads * self.head_dim
        )
        # attention shape: (N, heads, T, T)
        # values shape: (N, T, heads, head_dim)
        # out after matrix multiply: (N, T, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # The linear layer keeps the shape; final shape is (N, T, embed_size)

        return out


class STransformer(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(STransformer, self).__init__()
        self.attention = SSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.gcn = GCN(embed_size, embed_size, embed_size, dropout)
        self.norm_gcn = nn.InstanceNorm2d(1)

        self.dropout = nn.Dropout(dropout)
        self.out1_fc = nn.Linear(embed_size, embed_size)
        self.out2_fc = nn.Linear(embed_size, embed_size)

    def forward(self, value, key, query, adj):
        # Spatial Transformer branch

        # adj = adj.unsqueeze(2)
        # adj = adj.expand(4, 4, 64)  # concatenate the adjacency matrix
        # query = torch.cat((query, adj), 1)

        attention = self.attention(value, key, query)
        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out1 = self.dropout(self.norm2(forward + x))

        # GCN branch: normalize adj once, then apply the GCN per time step
        adj = self.norm_gcn(adj.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)

        outs = []
        for t in range(query.shape[1]):
            o = self.gcn(query[:, t, :], adj)  # (N, C)
            outs.append(o.unsqueeze(1))        # (N, 1, C)
        out2 = torch.cat(outs, dim=1)          # (N, T, C)

        # Gated fusion of the STransformer and GCN branches
        g = torch.sigmoid(self.out1_fc(out1) + self.out2_fc(out2))
        out = g * out1 + (1 - g) * out2

        return out
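
# The per-time-step GCN loop in STransformer.forward computes, for each t,
# adj @ (x[:, t, :] @ W). A sketch of the same product done in one shot with
# einsum (a standalone illustration assuming a dense adj; the model keeps the
# explicit loop because GraphConvolution works on 2-D inputs):
def _batched_graph_conv_demo(x, adj, weight):
    # x: (N, T, F), adj: (N, N), weight: (F, O)
    support = torch.einsum("ntf,fo->nto", x, weight)  # x[:, t, :] @ W
    return torch.einsum("mn,nto->mto", adj, support)  # adj @ support per t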
class TTransformer(nn.Module):
    def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
        super(TTransformer, self).__init__()
        # Temporal embedding via one-hot encoding of the time slot
        self.time_num = time_num
        self.one_hot = One_hot_encoder(embed_size, time_num)

        self.attention = TSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, i):
        onehot_encoder = self.one_hot(i, N=query.shape[0], T=query.shape[1])

        query = query + onehot_encoder

        attention = self.attention(value, key, query)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class STTransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, time_num, dropout, forward_expansion):
        super(STTransformerBlock, self).__init__()
        self.STransformer = STransformer(embed_size, heads, dropout, forward_expansion)
        self.TTransformer = TTransformer(embed_size, heads, time_num, dropout, forward_expansion)

    def forward(self, value, key, query, adj, i):
        # Spatial block then temporal block, each with a residual connection
        x1 = self.STransformer(value, key, query, adj) + query
        x2 = self.TTransformer(x1, x1, x1, i) + x1

        return x2

class Encoder(nn.Module):
    def __init__(
        self,
        embed_size,
        num_layers,
        heads,
        time_num,
        device,
        forward_expansion,
        dropout,
    ):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.layers = nn.ModuleList(
            [
                STTransformerBlock(
                    embed_size,
                    heads,
                    time_num,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj, i):
        N, T, C = x.shape
        out = self.dropout(x)

        # In the encoder the query, key and value are all the same tensor.
        for layer in self.layers:
            out = layer(out, out, out, adj, i)
        return out
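
# Minimal shape check for one spatial-temporal block (a standalone sketch;
# N, T and the time index i mirror the values used in train.py, and the
# random adjacency is only a placeholder):
def _st_block_shape_demo(N=25, T=12, embed_size=64, heads=2, i=0):
    block = STTransformerBlock(embed_size, heads, time_num=288,
                               dropout=0, forward_expansion=4)
    x = torch.randn(N, T, embed_size)
    adj = torch.rand(N, N)
    out = block(x, x, x, adj, i)
    assert out.shape == (N, T, embed_size)
    return out.shape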
class Transformer(nn.Module):
    def __init__(
        self,
        embed_size=512,
        num_layers=3,
        heads=8,
        time_num=288,
        forward_expansion=4,
        dropout=0,
        device="cpu",
    ):

        super(Transformer, self).__init__()
        self.encoder = Encoder(
            embed_size,
            num_layers,
            heads,
            time_num,
            device,
            forward_expansion,
            dropout,
        )

        self.device = device

    def forward(self, src, adj, i):
        enc_src = self.encoder(src, adj, i)
        return enc_src


class STTransformer(nn.Module):
    def __init__(
        self,
        in_channels=1,
        embed_size=512,
        time_num=288,
        num_layers=3,
        T_dim=12,          # number of input time steps
        output_T_dim=3,    # number of predicted steps (output channels of conv2)
        heads=2,
    ):

        super(STTransformer, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, embed_size, 1)
        self.Transformer = Transformer(embed_size, num_layers, heads=heads, time_num=time_num)
        self.conv2 = nn.Conv2d(T_dim, output_T_dim, 1)
        self.conv3 = nn.Conv2d(embed_size, 1, 1)

    def forward(self, x, adj, i):
        # x shape: [C, N, T]
        x = x.unsqueeze(0)
        input_Transformer = self.conv1(x)
        input_Transformer = input_Transformer.squeeze(0)
        input_Transformer = input_Transformer.permute(1, 2, 0)

        # input_Transformer shape: [N, T, C]
        output_Transformer = self.Transformer(input_Transformer, adj, i)

        output_Transformer = output_Transformer.permute(1, 0, 2)
        # output_Transformer shape: [T, N, C]

        output_Transformer = output_Transformer.unsqueeze(0)
        out = self.conv2(output_Transformer)  # out shape: [1, output_T_dim, N, C]

        out = out.permute(0, 3, 2, 1)  # out shape: [1, C, N, output_T_dim]
        out = self.conv3(out)          # out shape: [1, 1, N, output_T_dim]

        out = out.squeeze(0).squeeze(0)
        # out shape: [N, output_T_dim]

        return out
--------------------------------------------------------------------------------
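
A quick end-to-end smoke test for the model (a sketch using the
hyper-parameters from train.py: 25 sensors, 12 input steps, 3 predicted
steps, and a random stand-in for W_25.csv):

import torch
from ST_Transformer import STTransformer

model = STTransformer(in_channels=1, embed_size=64, time_num=288,
                      num_layers=1, T_dim=12, output_T_dim=3, heads=2)
x = torch.randn(1, 25, 12)   # [C, N, T]
adj = torch.rand(25, 25)     # stand-in adjacency
out = model(x, adj, 0)       # i = 0: window starts at the first time slot
print(out.shape)             # torch.Size([25, 3])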