# ├── metric
# │   └── main_metric.py
# ├── block
# │   ├── embed_block.py
# │   ├── revin.py
# │   ├── decoder_block.py
# │   └── TVA_block.py
# ├── README.md
# └── main_model.py
#
# /metric/main_metric.py:
# --------------------------------------------------------------------------------
# Point-forecast evaluation metrics. Every function takes NumPy arrays
# ``pred`` and ``true`` of the same shape.

import numpy as np


def RSE(pred, true):
    """Return the root relative squared error.

    Computes ``sqrt(sum((true - pred)**2)) / sqrt(sum((true - mean(true))**2))``,
    with the mean and both sums taken over all elements.
    """
    return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))


def CORR(pred, true):
    """Return the Pearson correlation along axis 0, averaged over the other axes.

    Positions where ``pred`` or ``true`` is constant along axis 0 give
    0/0 = NaN, which propagates into the mean (NumPy emits a RuntimeWarning).
    """
    true_c = true - true.mean(0)
    pred_c = pred - pred.mean(0)
    u = (true_c * pred_c).sum(0)
    # Pearson denominator is sqrt(sum(a**2) * sum(b**2)). The previous code
    # summed the element-wise product a**2 * b**2 instead, which is not a
    # correlation coefficient (e.g. CORR(x, x) was not 1).
    d = np.sqrt((true_c ** 2).sum(0) * (pred_c ** 2).sum(0))
    return (u / d).mean()


def Corr(pred, true):
    """Return the Pearson correlation along axis 0, averaged over valid positions.

    Positions where ``true`` is constant along axis 0 (zero standard
    deviation) are excluded. The exclusion happens before dividing, so these
    positions do not trigger divide-by-zero warnings. Positions where only
    ``pred`` is constant still give NaN. If every position is excluded, the
    result is the mean of an empty array (NaN).
    """
    sig_p = np.std(pred, axis=0)
    sig_g = np.std(true, axis=0)
    m_p = pred.mean(0)
    m_g = true.mean(0)
    ind = sig_g != 0
    cov = ((pred - m_p) * (true - m_g)).mean(0)
    # Select the valid positions first, then divide. This gives the same
    # values as dividing everything and masking afterwards, minus the
    # spurious 0/0 RuntimeWarnings for the masked-out positions.
    corr = cov[ind] / (sig_p * sig_g)[ind]
    return corr.mean()


def MAE(pred, true):
    """Return the mean absolute error."""
    return np.mean(np.abs(pred - true))


def MSE(pred, true):
    """Return the mean squared error."""
    return np.mean((pred - true) ** 2)


def RMSE(pred, true):
    """Return the root mean squared error."""
    return np.sqrt(MSE(pred, true))


def MAPE(pred, true):
    """Return the mean absolute percentage error, as a fraction rather than a percent.

    Elements where ``true == 0`` divide by zero and make the result inf/NaN.
    """
    return np.mean(np.abs((pred - true) / true))


def MSPE(pred, true):
    """Return the mean squared percentage error, as a fraction.

    Elements where ``true == 0`` divide by zero and make the result inf/NaN.
    """
    return np.mean(np.square((pred - true) / true))


def metric(pred, true):
    """Return ``(mae, mse, rmse, mape, mspe, corr)`` for ``pred`` against ``true``.

    ``corr`` is computed with :func:`Corr`, which skips constant targets.
    """
    mae = MAE(pred, true)
    mse = MSE(pred, true)
    rmse = RMSE(pred, true)
    mape = MAPE(pred, true)
    mspe = MSPE(pred, true)
    corr = Corr(pred, true)
    return mae, mse, rmse, mape, mspe, corr

# --------------------------------------------------------------------------------
# /block/embed_block.py:
# --------------------------------------------------------------------------------

import torch
from torch import nn, optim
import torch.nn.functional as F


class embed(nn.Module):
    """Double-sampling embedding of DSformer.

    Splits every variable's history into ``num_samp`` subsequences in two
    ways: strided (:meth:`down_sampling`) and contiguous
    (:meth:`Interval_sample`). If ``IF_node`` is set, it also appends a
    learnable per-variable embedding that is split the same way.

    Args:
        Input_len: history length ``H``. :meth:`down_sampling` requires it to
            be divisible by ``num_samp``.
        num_id: number of variables ``N``.
        num_samp: number of subsequences ``n``.
        IF_node: if True, concatenate the node embedding to each subsequence.

    ``forward(x)`` takes ``x`` of shape ``[B, N, H]``. It returns
    ``(x_1, x_2)``, each of shape ``[B, N, n, H // n]``, or
    ``[B, N, n, 2 * H // n]`` when ``IF_node`` is True.
    """

    def __init__(self, Input_len, num_id, num_samp, IF_node):
        super(embed, self).__init__()
        self.IF_node = IF_node
        self.num_samp = num_samp
        # NOTE(review): not used in forward(); kept so existing state_dict
        # checkpoints (which contain its weights) still load.
        self.embed_layer = nn.Linear(2 * Input_len, Input_len)

        # One learnable row of length Input_len per variable.
        self.node_emb = nn.Parameter(torch.empty(num_id, Input_len))
        nn.init.xavier_uniform_(self.node_emb)

    def forward(self, x):
        # [B, N, H] -> [B, N, H, 1]; the 4-way unpack also enforces a 3-D input.
        x = x.unsqueeze(-1)
        batch_size, _, _, _ = x.shape
        # Broadcast the node embedding to [B, N, H, 1] so it can be sampled like x.
        node_emb1 = self.node_emb.unsqueeze(0).expand(batch_size, -1, -1).unsqueeze(-1)

        x_1 = embed.down_sampling(x, self.num_samp)
        if self.IF_node:
            x_1 = torch.cat([x_1, embed.down_sampling(node_emb1, self.num_samp)], dim=-1)

        x_2 = embed.Interval_sample(x, self.num_samp)
        if self.IF_node:
            x_2 = torch.cat([x_2, embed.Interval_sample(node_emb1, self.num_samp)], dim=-1)

        return x_1, x_2

    @staticmethod
    def down_sampling(data, n):
        """Strided split along dim 2: subsequence ``i`` is ``data[:, :, i::n, :]``.

        ``data`` of shape ``[B, N, L, C]`` becomes ``[B, N, n * C, L // n]``.
        ``L`` must be divisible by ``n``; otherwise the slices have unequal
        lengths and ``torch.cat`` raises.
        """
        parts = [data[:, :, i::n, :] for i in range(n)]
        return torch.cat(parts, dim=3).transpose(2, 3)

    @staticmethod
    def Interval_sample(data, n):
        """Contiguous split along dim 2 into ``n`` chunks of length ``L // n``.

        ``data`` of shape ``[B, N, L, C]`` becomes ``[B, N, n * C, L // n]``.
        The trailing ``L % n`` elements are dropped.
        """
        data_len = data.shape[2] // n
        parts = [data[:, :, data_len * i:data_len * (i + 1), :] for i in range(n)]
        return torch.cat(parts, dim=3).transpose(2, 3)

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# # DSformer
#
# This GitHub repository corresponds to our paper published at CIKM 2023
# (DSformer: A double sampling transformer for multivariate time series
# long-term prediction).
#
# To manage all baselines and models from our lab in a unified way, DSformer's
# code is also stored together with the other baselines at the following link:
# https://github.com/zezhishao/BasicTS
#
# The complete parameter settings and training pipeline are available at the
# link above.
#
# This repository contains the model files for DSformer.
# Please note that we have optimized the DSformer code in this repository in
# order to comply with commercial regulations. In our tests, the current
# version performs better.
#
# The core hyperparameters are:
# - Input_len: history length
# - out_len: future (prediction) length
# - num_id: number of variables
# - num_layer: number of layers, 1 or 2 (in most cases, 1 is enough)
# - muti_head: number of attention heads, 1 to 4 (in most cases, 1 or 2 is enough)
# - dropout: dropout rate, 0.15 to 0.3
# - num_samp: number of subsequences, 2 or 3
# - IF_node: whether to use variable embedding, True or False (in most cases,
#   set to True)
#
# In addition, the learning-rate-related hyperparameters are:
# - Initial learning rate: 0.0002 (in the team's BasicTS environment, 0.002
#   might work better; in that case use milestones = [1,5,15,25,50,75,100] and
#   gamma = 0.5)
# - Learning rate decay strategy: MultiStepLR
# - milestones = [1,15,25,50,75,100], gamma = 0.5
# - clip_grad_norm_: max_norm = 3
# - batch size: 32
#
#
# If the code is helpful to you, please cite the following paper:
# ```bibtex
# @inproceedings{yu2023dsformer,
#   title={Dsformer: A double sampling transformer for multivariate time series long-term prediction},
#   author={Yu, Chengqing and Wang, Fei and Shao, Zezhi and Sun, Tao and Wu, Lin and Xu, Yongjun},
#   booktitle={Proceedings of the 32nd ACM International Conference on Information and Knowledge Management},
#   pages={3062--3072},
#   year={2023}
# }
# ```
# --------------------------------------------------------------------------------
# /main_model.py:
# --------------------------------------------------------------------------------

import torch
from torch import nn, optim
import torch.nn.functional as F
from block.embed_block import embed
from block.TVA_block import TVA_block_att
from block.decoder_block import TVADE_block
from block.revin import RevIN


class DSFormer(nn.Module):
    def __init__(self, Input_len, out_len, num_id, num_layer, dropout, muti_head, num_samp, IF_node):
        """Build the DSformer model.

        Args:
            Input_len: history length H. Must be divisible by ``num_samp``.
            out_len: future (prediction) length L.
            num_id: number of variables N.
            num_layer: number of encoder layers, typically 1 or 2.
            dropout: dropout rate, typically 0.15 to 0.3.
            muti_head: number of attention heads, typically 1 to 4.
            num_samp: number of subsequences per sampling, typically 2 or 3.
            IF_node: whether to append the learnable variable (node) embedding.

        Raises:
            ValueError: if ``num_samp < 1`` or ``Input_len`` is not divisible
                by ``num_samp``. The strided down-sampling cannot handle such
                lengths, so ``forward`` would fail anyway.
        """
        super(DSFormer, self).__init__()

        if num_samp < 1 or Input_len % num_samp != 0:
            raise ValueError(
                f"Input_len ({Input_len}) must be divisible by num_samp ({num_samp}), "
                "and num_samp must be >= 1"
            )

        # Per-subsequence feature length; doubled when the node embedding is appended.
        if IF_node:
            self.inputlen = 2 * Input_len // num_samp
        else:
            self.inputlen = Input_len // num_samp

        ### embed and encoder
        self.RevIN = RevIN(num_id)
        self.embed_layer = embed(Input_len, num_id, num_samp, IF_node)
        self.encoder = TVA_block_att(self.inputlen, num_id, num_layer, dropout, muti_head, num_samp)
        self.laynorm = nn.LayerNorm([self.inputlen])

        ### decoder
        self.decoder = TVADE_block(self.inputlen, num_id, dropout, muti_head)
        # 1x1 convolution mapping the feature length to the forecast horizon.
        self.output = nn.Conv1d(in_channels=self.inputlen, out_channels=out_len, kernel_size=1)

    def forward(self, x):
        # Input [B,H,N]: B is batch size. N is the number of variables. H is the history length
        # Output [B,L,N]: B is batch size. N is the number of variables. L is the future length

        ### embed: normalize per series, then [B,H,N] -> [B,N,H]
        x = self.RevIN(x, 'norm').transpose(-2, -1)
        x_1, x_2 = self.embed_layer(x)

        ### encoder: the same encoder (shared weights) processes both samplings
        x_1 = self.encoder(x_1)
        x_2 = self.encoder(x_2)
        x = x_1 + x_2
        x = self.laynorm(x)

        ### decoder
        x = self.decoder(x)
        x = self.output(x.transpose(-2, -1))
        x = self.RevIN(x, 'denorm')

        return x

# --------------------------------------------------------------------------------
# /block/revin.py:
# --------------------------------------------------------------------------------
# code from https://github.com/ts-kim/RevIN, with minor modifications

import torch
import torch.nn as nn


class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """Reversible instance normalization.

        :param num_features: the number of features or channels (size of the last dim)
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step instead of the mean
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or invert the normalization (``mode='denorm'``).

        'norm' stores the statistics of ``x`` on the module; 'denorm' reuses
        the statistics from the most recent 'norm' call.
        """
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError(f"Unsupported RevIN mode {mode!r}; expected 'norm' or 'denorm'")
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over every dim except the first (batch) and the last (features).
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            # NOTE(review): unlike mean/stdev this is not detached, so gradients
            # flow through it; confirm whether that is intended.
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by an affine weight of exactly zero.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x

# --------------------------------------------------------------------------------
# /block/decoder_block.py:
# --------------------------------------------------------------------------------

import torch
from torch import nn, optim
import torch.nn.functional as F


class TVADE_block(nn.Module):
    """Decoder block: temporal attention and variable attention fused by cross attention.

    As called from DSFormer, ``x`` is ``[B, N, L]`` and the output keeps that shape.
    """

    def __init__(self, Input_len, num_id, dropout, num_head=1):
        super(TVADE_block, self).__init__()
        self.Time_att = Time_de_att(Input_len, dropout, num_head)
        self.space_att = space_att2(Input_len, num_id, dropout, num_head)
        self.cross_att = cross_de_att(Input_len, dropout, num_head)

    def forward(self, x):
        return self.cross_att(self.Time_att(x), self.space_att(x))


### temporal attention
class Time_de_att(nn.Module):
    """Attention with 1x1-conv projections over the feature length, plus residual and LayerNorm.

    NOTE(review): every "head" uses the same q/k/v projections, so heads
    differ only through their dropout masks (identical in eval mode).
    """

    def __init__(self, dim_input, dropout, num_head):
        super(Time_de_att, self).__init__()
        self.query = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.key = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.value = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.laynorm = nn.LayerNorm([dim_input])
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=(1, num_head))

    def forward(self, x):
        x = x.transpose(-2, -1)
        # The projections are deterministic, so compute them once. Dropout is
        # still drawn per head in the original order (q, k, v, attention).
        q_proj = self.query(x)
        k_proj = self.key(x)
        v_proj = self.value(x)
        kd = torch.sqrt(torch.tensor(k_proj.shape[-1]).to(torch.float32) / self.num_head)
        heads = []
        for _ in range(self.num_head):
            q = self.dropout(q_proj).transpose(-2, -1)
            k = self.dropout(k_proj)
            v = self.dropout(v_proj).transpose(-2, -1)
            heads.append(self.dropout(self.softmax(q @ k / kd)) @ v)
        # Stack heads on a new last dim and mix them with a (1, num_head) conv.
        result = self.output(torch.stack(heads, dim=-1).transpose(1, 2))
        result = result.squeeze(-1)
        x = x + result
        x = x.transpose(-2, -1)
        return self.laynorm(x)


### space_attention
class space_att(nn.Module):
    """Conv1d-projected attention without residual.

    Not used by TVADE_block, which uses space_att2; kept for compatibility.
    """

    def __init__(self, Input_len, dim_input, dropout, num_head):
        super(space_att, self).__init__()
        self.query = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.key = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.value = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.linear1 = nn.Linear(num_head, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Projections computed once; dropout still drawn per head in the original order.
        q_proj = self.query(x)
        k_proj = self.key(x)
        v_proj = self.value(x)
        kd = torch.sqrt(torch.tensor(k_proj.shape[-1]).to(torch.float32) / self.num_head)
        heads = []
        for _ in range(self.num_head):
            q = self.dropout(q_proj).transpose(-2, -1)
            k = self.dropout(k_proj)
            v = self.dropout(v_proj).transpose(-2, -1)
            heads.append(self.dropout(self.softmax(q @ k / kd)) @ v)
        result = self.linear1(torch.stack(heads, dim=-1)).squeeze(-1)
        return result.transpose(-2, -1)


### space_attention2
class space_att2(nn.Module):
    """Attention with Linear projections over the variable dim (``dim_input`` = num_id).

    NOTE(review): q, k, v and the attention scores are computed once and
    shared by all heads; heads differ only through the dropout applied to
    the attention matrix.
    """

    def __init__(self, Input_len, dim_input, dropout, num_head):
        super(space_att2, self).__init__()
        self.query = nn.Linear(dim_input, dim_input)
        self.key = nn.Linear(dim_input, dim_input)
        self.value = nn.Linear(dim_input, dim_input)
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.linear1 = nn.Linear(num_head, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x.transpose(1, 2)
        q = self.dropout(self.query(x))
        k = self.dropout(self.key(x)).transpose(-2, -1)
        v = self.dropout(self.value(x))
        kd = torch.sqrt(torch.tensor(k.shape[-1]).to(torch.float32) / self.num_head)
        # Softmax is deterministic: compute it once, then draw dropout per head.
        scores = self.softmax(q @ k / kd)
        heads = [self.dropout(scores) @ v for _ in range(self.num_head)]
        result = self.linear1(torch.stack(heads, dim=-1)).squeeze(-1)
        return result.transpose(1, 2)


### cross attention
class cross_de_att(nn.Module):
    """Cross attention: queries from ``x2``, keys/values from ``x``, plus residual on ``x`` and LayerNorm.

    NOTE(review): heads share projections and differ only by dropout.
    """

    def __init__(self, dim_input, dropout, num_head):
        super(cross_de_att, self).__init__()
        self.query = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.key = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.value = nn.Conv1d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.laynorm = nn.LayerNorm([dim_input])
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=(1, num_head))

    def forward(self, x, x2):
        x = x.transpose(-2, -1)
        x2 = x2.transpose(-2, -1)
        # Projections computed once; dropout still drawn per head in the original order.
        q_proj = self.query(x2)
        k_proj = self.key(x)
        v_proj = self.value(x)
        kd = torch.sqrt(torch.tensor(k_proj.shape[-1]).to(torch.float32) / self.num_head)
        heads = []
        for _ in range(self.num_head):
            q = self.dropout(q_proj).transpose(-2, -1)
            k = self.dropout(k_proj)
            v = self.dropout(v_proj).transpose(-2, -1)
            heads.append(self.dropout(self.softmax(q @ k / kd)) @ v)
        result = self.output(torch.stack(heads, dim=-1).transpose(1, 2))
        result = result.squeeze(-1)
        x = x + result
        x = x.transpose(-2, -1)
        return self.laynorm(x)

# --------------------------------------------------------------------------------
# /block/TVA_block.py:
# --------------------------------------------------------------------------------

import torch
from torch import nn, optim
import torch.nn.functional as F


class TVA_block_att(nn.Module):
    """Encoder: temporal, variable and cross attention over the sampled subsequences.

    As called from DSFormer, ``x`` is ``[B, N, num_samp, Input_len]``. The
    subsequences are merged by a ``(num_samp, 1)`` convolution, and the
    result is ``[B, N, Input_len]``.

    NOTE(review): the same attention modules are applied ``num_layer`` times,
    so all layers share weights; confirm this is intended.
    """

    def __init__(self, Input_len, num_id, num_layer, dropout, num_head, num_samp):
        super(TVA_block_att, self).__init__()
        self.num_lay = num_layer
        self.Time_att = Time_att(Input_len, dropout, num_head)
        self.space_att = space_att2(Input_len, num_id, dropout, num_head)
        self.cross_att = cross_att(Input_len, dropout, num_head)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Conv2d(in_channels=Input_len, out_channels=Input_len, kernel_size=(num_samp, 1))

    def forward(self, x):
        for _ in range(self.num_lay):
            x = self.cross_att(self.Time_att(x), self.space_att(x))
        # [B, N, n, L] -> [B, L, n, N] -> conv over n -> [B, L, 1, N] -> [B, L, N]
        x = self.linear(x.transpose(-3, -1))
        x = x.squeeze(-2)
        return x.transpose(-2, -1)


### temporal attention
class Time_att(nn.Module):
    """Attention across the subsequence dim with 1x1-conv projections, plus residual and LayerNorm.

    NOTE(review): heads share projections and differ only by dropout.
    """

    def __init__(self, dim_input, dropout, num_head):
        super(Time_att, self).__init__()
        self.query = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.key = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.value = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.laynorm = nn.LayerNorm([dim_input])
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(num_head, 1)

    def forward(self, x):
        x = x.transpose(-3, -1)
        # Projections computed once; dropout still drawn per head in the original order.
        q_proj = self.query(x)
        k_proj = self.key(x)
        v_proj = self.value(x)
        # After transpose(-3, -1) then transpose(-2, -1), k's last dim is k_proj's dim -2.
        kd = torch.sqrt(torch.tensor(k_proj.shape[-2]).to(torch.float32) / self.num_head)
        heads = []
        for _ in range(self.num_head):
            q = self.dropout(q_proj).transpose(-3, -1)
            k = self.dropout(k_proj).transpose(-3, -1).transpose(-2, -1)
            v = self.dropout(v_proj).transpose(-3, -1)
            heads.append(self.dropout(self.softmax(q @ k / kd)) @ v)
        result = self.output(torch.stack(heads, dim=-1)).squeeze(-1)
        x = x.transpose(-3, -1) + result
        return self.laynorm(x)


### space_attention
class space_att(nn.Module):
    """Conv2d-projected attention without residual.

    Not used by TVA_block_att, which uses space_att2; kept for compatibility.
    """

    def __init__(self, Input_len, dim_input, dropout, num_head):
        super(space_att, self).__init__()
        self.query = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.key = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.value = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.linear1 = nn.Linear(num_head, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Projections computed once; dropout still drawn per head in the original order.
        q_proj = self.query(x)
        k_proj = self.key(x)
        v_proj = self.value(x)
        # After transpose(-3, -1) then transpose(-2, -1), k's last dim is k_proj's dim -2.
        kd = torch.sqrt(torch.tensor(k_proj.shape[-2]).to(torch.float32) / self.num_head)
        heads = []
        for _ in range(self.num_head):
            q = self.dropout(q_proj).transpose(-3, -1)
            k = self.dropout(k_proj).transpose(-3, -1).transpose(-2, -1)
            v = self.dropout(v_proj).transpose(-3, -1)
            heads.append(self.dropout(self.softmax(q @ k / kd)) @ v)
        result = self.linear1(torch.stack(heads, dim=-1)).squeeze(-1)
        return result.transpose(1, 3)


### space_attention2
class space_att2(nn.Module):
    """Attention with Linear projections over the variable dim (``dim_input`` = num_id).

    NOTE(review): q, k, v and the attention scores are shared by all heads;
    heads differ only through the dropout applied to the attention matrix.
    """

    def __init__(self, Input_len, dim_input, dropout, num_head):
        super(space_att2, self).__init__()
        self.query = nn.Linear(dim_input, dim_input)
        self.key = nn.Linear(dim_input, dim_input)
        self.value = nn.Linear(dim_input, dim_input)
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.linear1 = nn.Linear(num_head, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x.transpose(1, 3)
        q = self.dropout(self.query(x))
        k = self.dropout(self.key(x)).transpose(-2, -1)
        v = self.dropout(self.value(x))
        kd = torch.sqrt(torch.tensor(k.shape[-1]).to(torch.float32) / self.num_head)
        # Softmax is deterministic: compute it once, then draw dropout per head.
        scores = self.softmax(q @ k / kd)
        heads = [self.dropout(scores) @ v for _ in range(self.num_head)]
        result = self.linear1(torch.stack(heads, dim=-1)).squeeze(-1)
        return result.transpose(1, 3)


### cross attention
class cross_att(nn.Module):
    """Cross attention: queries from ``x2``, keys/values from ``x``, plus residual on ``x`` and LayerNorm.

    NOTE(review): heads share projections and differ only by dropout.
    """

    def __init__(self, dim_input, dropout, num_head):
        super(cross_att, self).__init__()
        self.query = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.key = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.value = nn.Conv2d(in_channels=dim_input, out_channels=dim_input, kernel_size=1)
        self.laynorm = nn.LayerNorm([dim_input])
        self.softmax = nn.Softmax(dim=-1)
        self.num_head = num_head
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(num_head, 1)

    def forward(self, x, x2):
        x = x.transpose(-3, -1)
        x2 = x2.transpose(-3, -1)
        # Projections computed once; dropout still drawn per head in the original order.
        q_proj = self.query(x2)
        k_proj = self.key(x)
        v_proj = self.value(x)
        # After transpose(-3, -1) then transpose(-2, -1), k's last dim is k_proj's dim -2.
        kd = torch.sqrt(torch.tensor(k_proj.shape[-2]).to(torch.float32) / self.num_head)
        heads = []
        for _ in range(self.num_head):
            q = self.dropout(q_proj).transpose(-3, -1)
            k = self.dropout(k_proj).transpose(-3, -1).transpose(-2, -1)
            v = self.dropout(v_proj).transpose(-3, -1)
            heads.append(self.dropout(self.softmax(q @ k / kd)) @ v)
        result = self.output(torch.stack(heads, dim=-1)).squeeze(-1)
        x = x.transpose(-3, -1) + result
        return self.laynorm(x)

# --------------------------------------------------------------------------------