├── Hyperparameters.py ├── LICENSE ├── README.md ├── data_fac.py ├── frameworks.py ├── layers.py ├── operations.py ├── pc │   ├── fig10.jpg │   └── framework.png ├── test.py ├── train.py └── utils_.py /Hyperparameters.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/9 11:21 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : Hyperparameters 6 | # @Project Name :metacode 7 | 8 | # To train on your own data, carefully tuning these parameters is necessary 9 | class Hyperparameters: 10 | #data 11 | data_04 = './data/PEMS04' 12 | data_08 = './data/PEMS08' 13 | data_nyc = './data/NYC' 14 | pkl_04 = 'PEM04.pkl' 15 | pkl_08 = 'PEM08.pkl' 16 | NYC_ = 'NYC.pkl' 17 | batch_size = 16#memory-sensitive 18 | input_len = 12#input time sequence length 19 | output_len = 12#output time sequence length 20 | #graph embedding 21 | d_node = 64 22 | walk_len = 16 23 | num_walks = 100 24 | p = 0.3 #PEMS08->0.3, NYC->1 25 | q = 0.7 #q<1: DFS-like; q>1: BFS-like; q=1: DeepWalk. PEMS08->0.7, NYC->2 26 | workers = 4 27 | eplision=0.1 28 | time_eplision = 0.3 29 | 30 | #model 31 | num_units =32#memory-sensitive 32 | out_units = 1 33 | if_te = True#temporal embedding 34 | if_se = True#spatial embedding 35 | if_ste = True # spatial & temporal embedding 36 | if_meta= True#meta learning 37 | drop_rate = 0.1 38 | num_heads=4#memory-sensitive 39 | num_begin_blocks=1#memory-sensitive 40 | num_medium_blocks =1#memory-sensitive 41 | num_end_blocks = 1#memory-sensitive 42 | #train 43 | if_onlydist = False 44 | if_onlytimesim=False 45 | if_val = True 46 | lr = 0.001 47 | logdir = 'logdir' 48 | num_epochs = 100 49 | init_val = 50 50 | ckpt_path = "./checkpoints/train" 51 | #test 52 | step_index = [2,5,8,11]#0-based indices of steps 3,6,9,12 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 YanJieWen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # STGMT-Tensorflow2-implementation 2 | Traffic prediction based on the spatial-temporal guided multi-graph Sandwich-Transformer (STGMT) 3 | 4 | 5 | ## Contents 6 | 7 | - [Background](#background) 8 | - [Preliminary](#preliminary) 9 | - [Dataset](#dataset) 10 | - [Weight](#weight) 11 | - [Training](#training) 12 | - [Testing](#testing) 13 | - [Results](#results) 14 | - [Contributing](#contributing) 15 | - [License](#license) 16 | 17 | ## Background 18 | 19 | The ability of spatial-temporal traffic prediction is crucial for urban computing, traffic management and future autonomous driving. In this paper, a novel spatial-temporal guided multi-graph Sandwich-Transformer model (STGMT) is proposed, which is composed of encoder, decoder and attention mechanism modules. The three modules are responsible for feature extraction from historical traffic data, autoregressive prediction, and capturing features along the spatial and temporal dimensions, respectively. Compared to the original Transformer framework, the output of the spatial-temporal embedding layer is introduced to guide the attention mechanism through meta learning, which accounts for the heterogeneity of spatial nodes and temporal nodes. The temporal and spatial features are encoded through Time2Vec (T2V) and Node2Vec (N2V), and coupled into spatial-temporal embedding blocks. In addition, a multi-graph is adopted to perform multi-head spatial self-attention (MSA). Finally, the attention modules and the feed-forward layers are recombined to form the Sandwich-Transformer. 20 | 21 | ## Preliminary 22 | Before entering this project, you may need to configure an environment based on `Tensorflow2.x-gpu`. 23 | ``` 24 | !pip install node2vec 25 | ``` 26 | 27 | ### Dataset 28 | 29 | If you want to run this project, please download the datasets and weight files from [Google Drive](https://drive.google.com/drive/folders/1SiCIIiJ9aejYxDNXSVNzly7S-fZ8tGaW?usp=sharing). Then put `checkpoints_NYC` or `checkpoints_pems08` into the project, renamed as `checkpoints`. After that, you can run [data_fac.py](data_fac.py) to generate data files in `pkl` format for your training and testing, which may be a long wait. The `pkl` file consists of six parts: `training data`, `validation data`, `test data`, `multi graph`, `node2vec results`, and the `inverse_transform` scaler. 30 | 31 | ### Weight 32 | If you only want to run inference rather than train on your own datasets, rename either checkpoint folder to `checkpoints`, for example `checkpoints_pems08 -> checkpoints`. 33 | 34 | 35 | ## Training 36 | The backbone of STGMT: 37 | ![image](pc/framework.png) 38 | 39 | [operations.py](operations.py), [layers.py](layers.py) and [frameworks.py](frameworks.py) are the most important components of this project. Moreover, you can come up with innovative ideas of your own, and you can also change the hyperparameters in [Hyperparameters.py](Hyperparameters.py) if you like. Before training your own datasets, you only need to change [train.py](train.py): at `line 18` you can change your dataset path via [Hyperparameters.py](Hyperparameters.py); at `line 52`, an L1 loss is used.
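For reference, here is a minimal sketch of how the generated `pkl` file is unpacked. It mirrors the loading code at the top of [train.py](train.py); the field order comes from `data_preprocessing` in [data_fac.py](data_fac.py):
```
from Hyperparameters import Hyperparameters as hp
from utils_ import read_pkl, toarray

data = read_pkl(hp.pkl_08)                    # e.g. 'PEM08.pkl' generated by data_fac.py
train, val, test = data[0], data[1], data[2]  # each: (samples, input_len+output_len, N, 1)
tms = toarray(data[3])                        # multi graph: [distance-sim matrix, time-sim matrix]
node2vec = data[4]                            # (N, d_node) node embeddings
scalar = data[5]                              # fitted MinMaxScaler, used for inverse_transform
```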
40 | So you can finally train the model by running the following command: 41 | ``` 42 | python train.py 43 | ``` 44 | Your trained weight files will be saved in the `checkpoints` folder. Don't worry about existing weight files in the folder: they will be overwritten during training, and the `CheckpointManager` in the code (`line 62` to `line 68`) supports resuming or continuing training later. 45 | 46 | ## Testing 47 | If you only want to run inference on our datasets, that is straightforward. Take the New York dataset as an example; PEMS08 follows the same procedure. 48 | [test.py](test.py) is the entry point; before testing, proceed as follows: 49 | ``` 50 | change the data path -> line 19 51 | change the error path and comparison path -> line 78, line 79 52 | python test.py 53 | ``` 54 | We provide three metrics: `MAE`, `RMSE`, and `MAPE`. 55 | 56 | In the end, the terminal will show the errors at steps `3, 6, 9, 12` and the average error of each step. Result tables will be saved into your project: `gap.csv` (the per-step errors) and `ana.xlsx` (ground truth and predictions). 57 | 58 | 59 | ## Results 60 | The result of the NYC prediction: 61 | 62 | 63 | 64 | ![image](pc/fig10.jpg) 65 | 66 | For more details, please see the paper! 67 | 68 | ## Contributing 69 | 70 | 71 | At last, many thanks to the co-authors for their contributions to the article, and to my girlfriend for giving me the courage to pursue a Ph.D. 72 | 73 | ## License 74 | 75 | [MIT](LICENSE) © YanjieWen 76 | 77 | -------------------------------------------------------------------------------- /data_fac.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/9 11:23 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : data_fac 6 | # @Project Name :metacode 7 | 8 | from Hyperparameters import Hyperparameters as hp 9 | import numpy as np 10 | import pickle 11 | import networkx as nx 12 | from node2vec import Node2Vec 13 | import pandas as pd 14 | import matplotlib.pyplot as plt 15 | from sklearn.preprocessing import MinMaxScaler 16 | import glob 17 | import os 18 | from utils_ import * 19 | 20 | 21 | np.random.seed(42) 22 | def get_data_path(dir,type): 23 | """ 24 | 25 | :param dir: rootdir path like './data/PEMS04' 26 | :param type: .npz,.csv 27 | :return: a full path 28 | """ 29 | files = [] 30 | for name in glob.glob('{}/*.{}'.format(dir,type)): 31 | files.append(name) 32 | return files[0] 33 | 34 | def read_flow_data(dir): 35 | """ 36 | 37 | :param dir: rootdir path like './data/PEMS04' 38 | :return: (points,sensors,flow) 39 | """ 40 | # pe = np.load(get_data_path(dir,'npz')) 41 | # datas = pe['data'][:,:,0,None]#for pkl 42 | datas = np.load(get_data_path(dir,'npy'))#for NYC 43 | return datas 44 | 45 | # def split_data_follow_day(datas,train_days,val_days): 46 | # """ 47 | # 48 | # :param datas: ((points,sensors,flow)) 49 | # :param train_days: value 50 | # :param val_days: value 51 | # :return: train,val,test data 52 | # """ 53 | # train_points = train_days*288#need to change 54 | # val_points = (train_days+val_days)*288#need to change 55 | # return [datas[:train_points,:,:],datas[train_points:val_points,:,:],datas[val_points:]] 56 | 57 | def get_datasets(datas): 58 | """ 59 | 60 | :param datas: train or val or test 61 | :return: train or val or test with shape(num_points,sequence_len,sensors,flow) 62 | """ 63 | sample_len = hp.input_len + hp.output_len 64 | datasets = [] 65 | for i in range(len(datas) - sample_len): 66 | datasets
.append(datas[np.newaxis,i:i + sample_len]) 67 | data = np.concatenate(datasets,axis=0) 68 | return data 69 | def read_graph(dir): 70 | """ 71 | 72 | :param dir: a dirpath 73 | :return: dataframe contain statics graph 74 | """ 75 | df = pd.read_csv(get_data_path(dir,'csv')) 76 | return df 77 | def build_adjmatrix(df): 78 | """ 79 | 80 | :param df: dataframe 81 | :return: (N,N) adj matrix with the distance of the each node 82 | """ 83 | nodes_id = list(set(list(set(df['from'].values)) + list(set(df['to'].values)))) 84 | nodes_id=sorted(nodes_id) 85 | nodes_id = np.array(nodes_id) 86 | adj_init = np.zeros((len(nodes_id), len(nodes_id))) 87 | for edge_idx in range(len(df)): 88 | from_id = df.iloc[edge_idx]['from'] 89 | to_id = df.iloc[edge_idx]['to'] 90 | adj_init[np.where(nodes_id==from_id), np.where(nodes_id==to_id)] = df.iloc[edge_idx]['cost'] 91 | adj_matrix = adj_init 92 | return adj_matrix 93 | 94 | 95 | def data_preprocessing(dir): 96 | """ 97 | 98 | :param dir: like './data/PEMS04 99 | :param train_day: a value 100 | :param val_day: a value 101 | :return: [[train],[val],[test],[dism,tsimm],node2v,sca] 102 | """ 103 | #get train,val,test 104 | # dir = './data/PEMS04' 105 | datas =read_flow_data(dir) 106 | datas_,sca = minmaxsca(datas)#sca is needed 107 | # datasets = split_data_follow_day(datas_,train_day,val_day)#a list 108 | #dynamic datasets results 109 | all_datas = get_datasets(datas_) 110 | train, val, test = split_train_val_test(all_datas) 111 | all_datas = [train, val, test] 112 | #to get statics graph 113 | graph_df = read_graph(dir) 114 | adj_matrix = build_adjmatrix(graph_df) 115 | dis_sim = sim_distance(adj_matrix, graph_df) 116 | dis_norm = norm_adjmatrix(dis_sim) 117 | time_sim_norm = time_sim_matrix(datas) 118 | #to get node2vec 119 | graph = build_graph(get_data_path(dir,'csv')) 120 | node2vec = Node2Vec(graph, dimensions=hp.d_node, walk_length=hp.walk_len, num_walks=hp.num_walks, 121 | p=hp.p, q=hp.q, workers=hp.workers) 122 | model = node2vec.fit() 123 | node_vec = toarray([model.wv[str(node)] for node in sorted(graph.nodes())]) # (N,dc),if error you coudle model.wv due to the version of gensim 124 | #all data summary 125 | all_datas.append([dis_norm,time_sim_norm]) 126 | all_datas.append(node_vec) 127 | all_datas.append(sca) 128 | return all_datas 129 | 130 | def to_pkl(datas,pkl_path): 131 | with open(pkl_path,'wb') as f: 132 | pickle.dump(datas,f) 133 | 134 | 135 | def main(): 136 | all_datas = data_preprocessing(hp.data_nyc) 137 | to_pkl(all_datas,hp.NYC_) 138 | if __name__ == '__main__': 139 | main() 140 | 141 | 142 | -------------------------------------------------------------------------------- /frameworks.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/9 11:22 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : frameworks 6 | # @Project Name :metacode 7 | 8 | import tensorflow as tf 9 | from layers import * 10 | from Hyperparameters import Hyperparameters as hp 11 | 12 | class STGMT(tf.keras.Model): 13 | def __init__(self,d_model,num_heads,rate=hp.drop_rate): 14 | super().__init__() 15 | self.encoder = Encoder(d_model,num_heads,rate) 16 | self.decoder = Decoder(d_model,num_heads,rate) 17 | self.output_layer = tf.keras.layers.Dense(hp.out_units) 18 | def call(self,enc_inp,dec_inp,n2v,tms,if_te=True,if_se=True,if_ste=True,if_meta=True,training=True): 19 | enc_out,att_t_botte_e,att_s_botte_e,att_t_middle_e,att_s_middle_e,stes = 
self.encoder(enc_inp,n2v,tms,if_te, 20 | if_se,if_ste,if_meta,training) 21 | enc_memory = enc_out 22 | ste_enc = stes 23 | dec_inp = tf.cast(dec_inp,tf.float32) 24 | dec_inp = tf.concat([enc_inp[:,-1:,:,:],dec_inp[:,:-1,:,:]],axis=1)#right shifted 25 | dec_out, att_t_botte_d, att_s_botte_d, att_ti_botte_d, att_t_middle_d, att_s_middle_d, att_ti_middle_d = \ 26 | self.decoder(dec_inp,n2v,tms,enc_memory,ste_enc,if_te, 27 | if_se,if_ste,if_meta,training) 28 | final_output = self.output_layer(dec_out) 29 | enc_atts = [att_t_botte_e,att_s_botte_e,att_t_middle_e,att_s_middle_e] 30 | dec_atts = [att_t_botte_d, att_s_botte_d, att_ti_botte_d, att_t_middle_d, att_s_middle_d, att_ti_middle_d] 31 | return final_output,enc_atts,dec_atts 32 | 33 | 34 | 35 | 36 | class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): # lr setting 37 | def __init__(self, d_model, warmup_steps=4000): 38 | super(CustomSchedule, self).__init__() 39 | 40 | self.d_model = d_model 41 | self.d_model = tf.cast(self.d_model, tf.float32) 42 | 43 | self.warmup_steps = warmup_steps 44 | 45 | def __call__(self, step): 46 | arg1 = tf.math.rsqrt(step) 47 | arg2 = step * (self.warmup_steps ** -1.5) 48 | 49 | return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) 50 | 51 | 52 | 53 | # def main(): 54 | # st_model = STGMT(16, 4) 55 | # enc_inp = tf.random.uniform((2, hp.input_len, 14, 1)) 56 | # dec_inp = tf.random.uniform((2, hp.output_len, 14, 1)) 57 | # n2v = tf.random.uniform((14, 64)) 58 | # tms = tf.random.uniform((2, 14, 14)) 59 | # final_output, enc_atts, dec_atts = st_model(enc_inp, dec_inp, n2v, tms) 60 | # print(final_output.shape) 61 | # print(st_model.summary()) 62 | # if __name__ == '__main__': 63 | # main() 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/9 11:22 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : layers 6 | # @Project Name :metacode 7 | import tensorflow as tf 8 | import numpy as np 9 | from Hyperparameters import Hyperparameters as hp 10 | from operations import * 11 | 12 | class Projection(tf.keras.layers.Layer): 13 | def __init__(self,d_model,**kwargs): 14 | ''' 15 | Non-Linear transformation of historical input 16 | Args: 17 | d_model: transform dimension 18 | **kwargs: None 19 | ''' 20 | super(Projection,self).__init__(**kwargs) 21 | self.d_model = d_model 22 | self.project = tf.keras.layers.Dense(self.d_model,activation=tf.keras.activations.relu,use_bias=False) 23 | def call(self,inp): 24 | return self.project(inp) 25 | 26 | class Temporal_embedding_layer(tf.keras.layers.Layer): 27 | def __init__(self,d_model,t_len,**kwargs): 28 | ''' 29 | Te with time2vec->(B,T,N,D) 30 | Args: 31 | d_model: a value 32 | t_len: a value 33 | **kwargs: 34 | ''' 35 | super(Temporal_embedding_layer,self).__init__(**kwargs) 36 | self.d_model = d_model 37 | self.t_len = t_len 38 | self.dense1 = tf.keras.layers.Dense(1,activation=tf.keras.activations.linear,use_bias=True) 39 | self.dense2 = tf.keras.layers.Dense(self.d_model,activation=tf.keras.activations.linear,use_bias=True) 40 | self.concat = tf.keras.layers.Concatenate(axis=-1) 41 | self.dense3 = tf.keras.layers.Dense(self.d_model,activation=tf.keras.activations.linear,use_bias=False) 42 | self.dense4 = tf.keras.layers.Dense(self.d_model,activation=tf.keras.activations.linear,use_bias=False) 43 | def call(self,inp): 44 | B,T,N,C = 
get_shape(inp) 45 | inp_ = tf.reshape(tf.transpose(inp,perm=[0,2,1,3]),(-1,T,C)) 46 | origin = self.dense1(inp_) 47 | sin_ = tf.math.sin(self.dense2(inp_)) 48 | te_out = self.concat((sin_,origin)) 49 | te_out = self.dense3(te_out) 50 | out_ = tf.transpose(tf.reshape(te_out, (-1, N, T, self.d_model)), perm=[0, 2, 1, 3]) 51 | out_ += positional_embedding(self.t_len, self.d_model) 52 | out_ = self.dense4(out_) 53 | return out_ 54 | 55 | class Spatial_embedding_layers(tf.keras.layers.Layer): 56 | def __init__(self,d_model,**kwargs): 57 | ''' 58 | node2vec after linear transform->(N,d) 59 | Args: 60 | d_model: a value 61 | **kwargs: 62 | ''' 63 | super(Spatial_embedding_layers,self).__init__(**kwargs) 64 | self.d_model= d_model 65 | self.dense = tf.keras.layers.Dense(self.d_model,use_bias=False) 66 | def call(self,inp): 67 | return self.dense(inp) 68 | 69 | 70 | class STE_embdding_layers(tf.keras.layers.Layer): 71 | def __init__(self,d_model,**kwargs): 72 | super(STE_embdding_layers,self).__init__(**kwargs) 73 | self.d_model = d_model 74 | self.dense = tf.keras.layers.Dense(self.d_model,use_bias=False) 75 | def call(self,te,se,if_te=True,if_se=True): 76 | ''' 77 | 78 | Args: 79 | te: tensor (b,t,n,d) 80 | se: tensor(n,d) 81 | if_te: a boolen 82 | if_se: a boolen 83 | 84 | Returns:ste (b,t,n,d) 85 | 86 | ''' 87 | B,T,N,D = get_shape(te) 88 | se_ = tf.tile(tf.expand_dims(tf.expand_dims(se,0),1),[B,T,1,1]) 89 | if if_te and if_se: 90 | ste = se_+te 91 | return self.dense(ste) 92 | elif if_te and not if_se: 93 | return te 94 | elif if_se and not if_te: 95 | return se_ 96 | 97 | class QKV_projection(tf.keras.layers.Layer): 98 | def __init__(self,d_model): 99 | super().__init__() 100 | self.d_model = d_model 101 | self.wq = tf.keras.layers.Dense(d_model) 102 | self.wk = tf.keras.layers.Dense(d_model) 103 | self.wv = tf.keras.layers.Dense(d_model) 104 | def call(self,quries,keys,values): 105 | ''' 106 | original projection for multi head attention 107 | Args: 108 | quries: b,tq,n,d 109 | keys: b,tk,n,d 110 | values: b,tk,n,d 111 | 112 | Returns: three same shape tensors ->(b,t,n,d) 113 | 114 | ''' 115 | return self.wq(quries),self.wk(keys),self.wv(values) 116 | 117 | class Meta_projection(tf.keras.layers.Layer): 118 | def __init__(self,d_model): 119 | super().__init__() 120 | self.d1 = d_model/2 121 | assert d_model%2==0,'Dimension should be even!' 
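# Note: the meta networks below (meta1 -> meta2) map the STE tensor (b,t,n,d) to d_model**2 values per node/time step;
# reshape() then folds these into per-position weight matrices (b,t,n,d,d), which meta_guide() applies to q/k/v.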
122 | self.d2 = d_model 123 | self.meta1 = [tf.keras.layers.Dense(d_model,activation=tf.keras.activations.relu,use_bias=True) for _ in range(3)] 124 | self.meta2 = [tf.keras.layers.Dense(d_model**2,activation=tf.keras.activations.linear,use_bias=True) for _ in range(3)] 125 | def reshape(self,out): 126 | b,t,n,d = get_shape(out) 127 | return tf.reshape(out,[-1,t,n,self.d2,self.d2]) 128 | def call(self,stq_q,stq_k,stq_v): 129 | ''' 130 | Generating the weight matrix by meta-learning with shape (b,t,n,d,d) 131 | Args: 132 | stq_q: b,t,n,d 133 | stq_k: b,t,n,d 134 | stq_v: b,t,n,d 135 | 136 | Returns:w->(b,t,n,d,d) 137 | 138 | ''' 139 | inp = [stq_q,stq_k,stq_v] 140 | w_list = [] 141 | for i in range(3): 142 | out1_ = self.meta1[i](inp[i]) 143 | out2_ = self.meta2[i](out1_) 144 | out_ = self.reshape(out2_) 145 | w_list.append(out_) 146 | return w_list 147 | 148 | class Multi_head_temporal_attention(tf.keras.layers.Layer): 149 | def __init__(self,d_model,num_heads): 150 | super().__init__() 151 | self.d_model = d_model 152 | self.num_heads = num_heads 153 | assert d_model%num_heads==0,"d_model must be an integer multiple of num_heads" 154 | self.qkv_proj = QKV_projection(d_model) 155 | self.meta_pro = Meta_projection(d_model) 156 | self.dense = tf.keras.layers.Dense(self.d_model,use_bias=False) 157 | def split_head(self,x): 158 | return tf.concat(tf.split(x,self.num_heads,axis=-1),axis=0)#-->(-1*h,t,n,d/h) 159 | def temporal_reshape(self,x): 160 | b,t,_,d = get_shape(x) 161 | x_ = tf.reshape(tf.transpose(x,perm=[0,2,1,3]),[-1,t,d]) 162 | return x_ 163 | def concat_head(self,x): 164 | return tf.concat(tf.split(x,self.num_heads,axis=0),axis=-1) 165 | def call(self,queries,keys,values,tms,ste_q,ste_k,ste_v,if_ste,if_meta,if_future,if_spatial=False): 166 | ''' 167 | Attention mechanisms in the temporal dimension 168 | Args: 169 | queries: b,t,n,d 170 | keys: b,t,n,d 171 | values: b,t,n,d 172 | ste_q: b,t,n,d 173 | ste_k: b,t,n,d 174 | ste_v: b,t,n,d 175 | if_ste: boolean 176 | if_meta: boolean 177 | if_future: boolean 178 | if_spatial: boolean 179 | 180 | Returns: (b,t,n,d),(bnh,tq,tk) 181 | 182 | ''' 183 | if if_ste and not if_meta: 184 | q,k,v = self.qkv_proj(tf.add(queries,ste_q),tf.add(keys,ste_k),tf.add(values,ste_v)) 185 | elif if_ste and if_meta: 186 | w_list = self.meta_pro(ste_q,ste_k,ste_v) 187 | wq,wk,wv = w_list 188 | q = meta_guide(queries,wq) 189 | k = meta_guide(keys,wk) 190 | v = meta_guide(values,wv) 191 | elif not if_ste: 192 | q,k,v = self.qkv_proj(queries,keys,values)#->(b,tk,n,d) 193 | b, t, n, d = get_shape(q) 194 | q,k,v = list(map(self.temporal_reshape,[q,k,v]))#->(bn,t,d) 195 | q_,k_,v_ = list(map(self.split_head,[q,k,v]))#->(bnh,t,d/h) 196 | out_ ,att = scale_dot_product_attention(q_,k_,v_,tm=tms,if_future=if_future,if_spatial=if_spatial)#-->(bnh,tq,d),(bnh,t,t) 197 | out_ = self.concat_head(out_)#->(bn,t,d) 198 | out_ = tf.transpose(tf.reshape(out_,(-1,n,t,d)),[0,2,1,3]) 199 | out_ = self.dense(out_) 200 | return out_,att 201 | 202 | class Multi_head_spatial_attention(tf.keras.layers.Layer): 203 | def __init__(self,d_model,num_heads): 204 | super().__init__() 205 | self.d_model = d_model 206 | self.num_heads = num_heads 207 | assert d_model%num_heads==0,"d_model must be an integer multiple of num_heads" 208 | self.qkv_proj = QKV_projection(d_model) 209 | self.meta_pro = Meta_projection(d_model) 210 | self.dense = tf.keras.layers.Dense(self.d_model,use_bias=False) 211 | def split_head(self,x): 212 | return
tf.concat(tf.split(x,self.num_heads,axis=-1),axis=0)#-->(-1*h,t,n,d/h) 213 | def spatial_reshape(self,x): 214 | b,_,n,d = get_shape(x) 215 | return tf.reshape(x,(-1,n,d)) 216 | def concat_head(self,x): 217 | return tf.concat(tf.split(x,self.num_heads,axis=0),axis=-1) 218 | def call(self,queries,keys,values,tms,ste_q,ste_k,ste_v,if_ste,if_meta,if_future=False,if_spatial=True): 219 | ''' 220 | Attention mechanisms in the spatial dimension 221 | Args: 222 | queries: b,t,n,d 223 | keys: b,t,n,d 224 | values: b,t,n,d 225 | tms: m,n,n 226 | ste_q: b,t,n,d 227 | ste_k: b,t,n,d 228 | ste_v: b,t,n,d 229 | if_ste: boolean 230 | if_meta: boolean 231 | if_future: boolean 232 | if_spatial: boolean 233 | 234 | Returns:(b,t,n,d),(bth,n,n) 235 | 236 | ''' 237 | outs = [] 238 | atts = [] 239 | tms = tf.cast(tms,tf.float32) 240 | b, t, n, d = get_shape(queries) 241 | for num in range(tms.shape[0]):#tms->(m,n,n) 242 | if if_ste and not if_meta: 243 | q,k,v = self.qkv_proj(tf.add(queries,ste_q),tf.add(keys,ste_k),tf.add(values,ste_v)) 244 | elif if_ste and if_meta: 245 | w_list = self.meta_pro(ste_q,ste_k,ste_v) 246 | wq,wk,wv = w_list 247 | q = meta_guide(queries,wq) 248 | k = meta_guide(keys,wk) 249 | v = meta_guide(values,wv) 250 | elif not if_ste: 251 | q,k,v = self.qkv_proj(queries,keys,values)#->(b,tk,n,d) 252 | q, k, v = list(map(self.spatial_reshape, [q, k, v])) # ->(bt,n,d) 253 | q_, k_, v_ = list(map(self.split_head, [q, k, v])) # ->(bth,n,d/h) 254 | out_, att = scale_dot_product_attention(q_, k_, v_, tm=tms[num], if_future=if_future, if_spatial=if_spatial) 255 | atts.append(att) 256 | out_ = self.concat_head(out_)#->(bt,n,d) 257 | outs.append(out_) 258 | out = tf.transpose(tf.convert_to_tensor(outs),[1,2,3,0])#->(m,bt,n,d)->(bt,n,d,m) 259 | out = tf.reshape(out,(-1,t,n,d*tms.shape[0]))#(b,t,n,dm) 260 | atts = tf.convert_to_tensor(atts)#->(m,bht,n,n) 261 | out = self.dense(out)#->(b,t,n,d) 262 | return out,atts 263 | 264 | class FFN(tf.keras.layers.Layer): 265 | def __init__(self,d_model): 266 | super().__init__() 267 | self.d1 = d_model//4 268 | assert d_model%4==0,"The d should be divisible by 4" 269 | self.d2 = d_model 270 | self.dense1 = tf.keras.layers.Dense(self.d1,activation=tf.keras.activations.relu,use_bias=False) 271 | self.dense2 = tf.keras.layers.Dense(self.d2,activation=tf.keras.activations.linear,use_bias=False) 272 | def call(self,inp): 273 | ''' 274 | dimension transformation 275 | Args: 276 | inp:(b,t,n,d) 277 | 278 | Returns:(b,t,n,d) 279 | 280 | ''' 281 | out_ = self.dense1(inp) 282 | out_ = self.dense2(out_) 283 | return out_ 284 | 285 | class Encoderlayer(tf.keras.layers.Layer): 286 | def __init__(self,d_model,num_heads,rate=hp.drop_rate,b_layers=hp.num_begin_blocks,m_layers=hp.num_medium_blocks,t_layers=hp.num_end_blocks): 287 | super().__init__() 288 | self.b_layers = b_layers 289 | self.m_layers = m_layers 290 | self.t_layers = t_layers 291 | #Bottom layer 292 | self.mtas1 = [Multi_head_temporal_attention(d_model,num_heads) for _ in range(b_layers)] 293 | self.dropouts1_0 = [tf.keras.layers.Dropout(rate) for _ in range(b_layers) ] 294 | self.layernorms1_0 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(b_layers)] 295 | self.msas1 = [Multi_head_spatial_attention(d_model,num_heads) for _ in range(b_layers)] 296 | self.dropouts1_1 = [tf.keras.layers.Dropout(rate) for _ in range(b_layers)] 297 | self.layernorms1_1 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(b_layers)] 298 | # Middle layer 299 | self.mtas2 = 
[Multi_head_temporal_attention(d_model, num_heads) for _ in range(m_layers)] 300 | self.dropouts2_0 = [tf.keras.layers.Dropout(rate) for _ in range(m_layers)] 301 | self.layernorms2_0 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(m_layers)] 302 | self.msas2 = [Multi_head_spatial_attention(d_model, num_heads) for _ in range(m_layers)] 303 | self.dropouts2_1 = [tf.keras.layers.Dropout(rate) for _ in range(m_layers)] 304 | self.layernorms2_1 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(m_layers)] 305 | self.ffns1 = [FFN(d_model) for _ in range(m_layers)] 306 | self.dropouts2_2 = [tf.keras.layers.Dropout(rate) for _ in range(m_layers)] 307 | self.layernorms2_2 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(m_layers)] 308 | # Top layer 309 | self.ffns2 = [FFN(d_model) for _ in range(t_layers)] 310 | self.dropouts3_0 = [tf.keras.layers.Dropout(rate) for _ in range(t_layers)] 311 | self.layernorms3_0 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(t_layers)] 312 | def call(self,queries,keys,values,tms,ste_q,ste_k,ste_v,if_ste,if_meta,training=True): 313 | # Bottom layer 314 | #(bnh,t,t) 315 | for i in range(self.b_layers): 316 | out_, att_t_botte_e = self.mtas1[i](queries,keys,values,None,ste_q,ste_k,ste_v,if_ste,if_meta,if_future=False,if_spatial=False) 317 | out_ = self.dropouts1_0[i](out_,training=training) 318 | out_ = self.layernorms1_0[i](queries+out_) 319 | #(m,bht,n,n) 320 | queries=keys=values=out_ 321 | out_, att_s_botte_e = self.msas1[i](queries,keys,values,tms, ste_q, ste_k, ste_v, if_ste, if_meta, if_future=False,if_spatial=True) 322 | out_ = self.dropouts1_1[i](out_, training=training) 323 | out_ = self.layernorms1_1[i](queries + out_) 324 | queries = keys = values = out_ 325 | # Middle layer 326 | for i in range(self.m_layers): 327 | out_, att_t_middle_e = self.mtas2[i](queries, keys, values,None, ste_q, ste_k, ste_v, if_ste, if_meta, if_future=False,if_spatial=False) 328 | out_ = self.dropouts2_0[i](out_, training=training) 329 | out_ = self.layernorms2_0[i](queries + out_) 330 | queries = keys = values = out_ 331 | out_, att_s_middle_e = self.msas2[i](queries, keys, values, tms, ste_q, ste_k, ste_v, if_ste, if_meta, if_future=False,if_spatial=True) 332 | out_ = self.dropouts2_1[i](out_, training=training) 333 | out1 = self.layernorms2_1[i](queries + out_) 334 | out_ = self.ffns1[i](out1) 335 | out_ = self.dropouts2_2[i](out_, training=training) 336 | out2 = self.layernorms2_2[i](out_+out1) 337 | queries = keys = values = out_ 338 | for i in range(self.t_layers): 339 | # Top layer 340 | out_ = self.ffns2[i](out2) 341 | out_ = self.dropouts3_0[i](out_, training=training) 342 | out_ = self.layernorms3_0[i](out_ + out2) 343 | out2 = out_ 344 | return out_,att_t_botte_e,att_s_botte_e,att_t_middle_e,att_s_middle_e#only the last layer are reversed 345 | 346 | class Encoder(tf.keras.layers.Layer): 347 | def __init__(self,d_model,num_heads,rate=hp.drop_rate,b_layers=hp.num_begin_blocks,m_layers=hp.num_medium_blocks,t_layers=hp.num_end_blocks): 348 | super().__init__() 349 | self.d_model = d_model 350 | self.num_heads = num_heads 351 | 352 | #embeeding layers 353 | self.project = Projection(d_model) 354 | self.te = Temporal_embedding_layer(d_model,hp.input_len) 355 | self.se = Spatial_embedding_layers(d_model) 356 | self.ste_embedding = STE_embdding_layers(d_model) 357 | #enc_layers 358 | self.encoder_layer = Encoderlayer(d_model,num_heads,rate=rate,b_layers=b_layers,m_layers=m_layers,t_layers=t_layers) 359 | 
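# call(): project the raw input, build the temporal (T2V + positional) and spatial (N2V) embeddings,
# fuse them into the STE, then run the sandwich encoder stack (temporal attention -> spatial attention -> FFN).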
360 | def call(self,x,n2v,tms,if_te=True,if_se=True,if_ste=True,if_meta=True,training=True): 361 | enc_inp = self.project(x) 362 | te_inp = self.te(x) 363 | se_inp = self.se(n2v) 364 | stes = self.ste_embedding(te_inp,se_inp,if_te,if_se) 365 | ste_q = ste_k = ste_v=stes 366 | queries=keys=values = enc_inp 367 | enc_out,att_t_botte_e,att_s_botte_e,att_t_middle_e,att_s_middle_e = self.encoder_layer(queries,keys,values,tms,ste_q,ste_k,ste_v,if_ste,if_meta,training=training) 368 | return enc_out,att_t_botte_e,att_s_botte_e,att_t_middle_e,att_s_middle_e,stes 369 | 370 | 371 | class Decoderlayer(tf.keras.layers.Layer): 372 | def __init__(self,d_model,num_heads,rate=hp.drop_rate,b_layers=hp.num_begin_blocks,m_layers=hp.num_medium_blocks,t_layers=hp.num_end_blocks): 373 | super().__init__() 374 | self.b_layers = b_layers 375 | self.m_layers = m_layers 376 | self.t_layers = t_layers 377 | #Bottom layer 378 | self.mtas1 = [Multi_head_temporal_attention(d_model,num_heads) for _ in range(b_layers)] 379 | self.dropouts1_0 = [tf.keras.layers.Dropout(rate) for _ in range(b_layers) ] 380 | self.layernorms1_0 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(b_layers)] 381 | self.msas1 = [Multi_head_spatial_attention(d_model,num_heads) for _ in range(b_layers)] 382 | self.dropouts1_1 = [tf.keras.layers.Dropout(rate) for _ in range(b_layers)] 383 | self.layernorms1_1 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(b_layers)] 384 | self.mtias1 = [Multi_head_temporal_attention(d_model,num_heads) for _ in range(b_layers)] 385 | self.dropouts1_2 = [tf.keras.layers.Dropout(rate) for _ in range(b_layers)] 386 | self.layernorms1_2 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(b_layers)] 387 | # Middle layer 388 | self.mtas2 = [Multi_head_temporal_attention(d_model, num_heads) for _ in range(m_layers)] 389 | self.dropouts2_0 = [tf.keras.layers.Dropout(rate) for _ in range(m_layers)] 390 | self.layernorms2_0 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(m_layers)] 391 | self.msas2 = [Multi_head_spatial_attention(d_model, num_heads) for _ in range(m_layers)] 392 | self.dropouts2_1 = [tf.keras.layers.Dropout(rate) for _ in range(m_layers)] 393 | self.layernorms2_1 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(m_layers)] 394 | self.mtias2 = [Multi_head_temporal_attention(d_model, num_heads) for _ in range(m_layers)] 395 | self.dropouts2_2 = [tf.keras.layers.Dropout(rate) for _ in range(m_layers)] 396 | self.layernorms2_2 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(m_layers)] 397 | self.ffns1 = [FFN(d_model) for _ in range(m_layers)] 398 | self.dropouts2_3 = [tf.keras.layers.Dropout(rate) for _ in range(m_layers)] 399 | self.layernorms2_3 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(m_layers)] 400 | # Top layer 401 | self.ffns2 = [FFN(d_model) for _ in range(t_layers)] 402 | self.dropouts3_0 = [tf.keras.layers.Dropout(rate) for _ in range(t_layers)] 403 | self.layernorms3_0 = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(t_layers)] 404 | def call(self,queries,keys,values,enc_memory,tms,ste_enc,ste_q,ste_k,ste_v,if_ste,if_meta,training=True): 405 | # Bottom layer 406 | #(bnh,t,t) 407 | for i in range(self.b_layers): 408 | out_, att_t_botte_d = self.mtas1[i](queries,keys,values,None,ste_q,ste_k,ste_v,if_ste,if_meta,if_future=True,if_spatial=False) 409 | out_ = self.dropouts1_0[i](out_,training=training) 410 | out_ = self.layernorms1_0[i](queries+out_) 411 | 
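# next in this bottom block: multi-graph spatial attention (msas1), then temporal cross-attention with the encoder memory (mtias1)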
#(m,bht,n,n) 412 | queries=out_ 413 | out_, att_s_botte_d = self.msas1[i](out_,out_,out_,tms, ste_q, ste_k, ste_v, if_ste, if_meta, if_future=False,if_spatial=True) 414 | out_ = self.dropouts1_1[i](out_, training=training) 415 | out_ = self.layernorms1_1[i](queries + out_) 416 | queries = out_ 417 | out_, att_ti_botte_d = self.mtias1[i](out_, enc_memory, enc_memory, None, ste_q, ste_enc, ste_enc, if_ste, if_meta, 418 | if_future=False, if_spatial=False) 419 | out_ = self.dropouts1_2[i](out_, training=training) 420 | out_ = self.layernorms1_2[i](queries + out_) 421 | queries = keys = values = out_ 422 | # Middle layer 423 | for i in range(self.m_layers): 424 | out_, att_t_middle_d = self.mtas2[i](queries, keys, values,None, ste_q, ste_k, ste_v, if_ste, if_meta, if_future=True,if_spatial=False) 425 | out_ = self.dropouts2_0[i](out_, training=training) 426 | out_ = self.layernorms2_0[i](queries + out_) 427 | queries = out_ 428 | out_, att_s_middle_d = self.msas2[i](out_, out_,out_, tms, ste_q, ste_k, ste_v, if_ste, if_meta, if_future=False,if_spatial=True) 429 | out_ = self.dropouts2_1[i](out_, training=training) 430 | out_ = self.layernorms2_1[i](queries + out_) 431 | queries = out_ 432 | out_, att_ti_middle_d = self.mtias2[i](out_, enc_memory, enc_memory, None, ste_q, ste_enc, ste_enc, if_ste, 433 | if_meta,if_future=False, if_spatial=False) 434 | out_ = self.dropouts2_2[i](out_, training=training) 435 | out1 = self.layernorms2_2[i](queries + out_) 436 | out_ = self.ffns1[i](out1) 437 | out_ = self.dropouts2_3[i](out_, training=training) 438 | out2 = self.layernorms2_3[i](out_+out1) 439 | queries = keys = values = out_ 440 | for i in range(self.t_layers): 441 | # Top layer 442 | out_ = self.ffns2[i](out2) 443 | out_ = self.dropouts3_0[i](out_, training=training) 444 | out_ = self.layernorms3_0[i](out_ + out2) 445 | out2 = out_ 446 | return out_,att_t_botte_d,att_s_botte_d,att_ti_botte_d,att_t_middle_d,att_s_middle_d,att_ti_middle_d#only the last layer are reversed 447 | 448 | class Decoder(tf.keras.layers.Layer): 449 | def __init__(self,d_model,num_heads,rate=hp.drop_rate,b_layers=hp.num_begin_blocks,m_layers=hp.num_medium_blocks,t_layers=hp.num_end_blocks): 450 | super().__init__() 451 | self.d_model = d_model 452 | self.num_heads = num_heads 453 | 454 | #embeeding layers 455 | self.project = Projection(d_model) 456 | self.te = Temporal_embedding_layer(d_model,hp.output_len) 457 | self.se = Spatial_embedding_layers(d_model) 458 | self.ste_embedding = STE_embdding_layers(d_model) 459 | #dec_layers 460 | self.decoder_layer = Decoderlayer(d_model,num_heads,rate=rate,b_layers=b_layers,m_layers=m_layers,t_layers=t_layers) 461 | 462 | def call(self,x,n2v,tms,enc_memory,ste_enc,if_te=True, 463 | if_se=True,if_ste=True,if_meta=True,training=True): 464 | dec_inp = self.project(x) 465 | te_inp = self.te(x) 466 | se_inp = self.se(n2v) 467 | stes = self.ste_embedding(te_inp,se_inp,if_te,if_se) 468 | ste_q = ste_k = ste_v=stes 469 | queries=keys=values = dec_inp 470 | dec_out,att_t_botte_d,att_s_botte_d,att_ti_botte_d,att_t_middle_d,att_s_middle_d,att_ti_middle_d = self.decoder_layer(queries,keys,values,enc_memory,tms, 471 | ste_enc,ste_q,ste_k,ste_v,if_ste,if_meta,training=training) 472 | return dec_out,att_t_botte_d,att_s_botte_d,att_ti_botte_d,att_t_middle_d,att_s_middle_d,att_ti_middle_d 473 | 474 | -------------------------------------------------------------------------------- /operations.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/11/20 
16:10 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : operations 6 | # @Project Name :STGMT-Tensorflow-implementation-master 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | 12 | def get_shape(tensor): 13 | ''' 14 | obtain the shape of tensor 15 | Args: 16 | tensor: a tensor (b,t,n,d) 17 | 18 | Returns:[b,t,n,d] 19 | 20 | ''' 21 | return tensor.get_shape().as_list() 22 | 23 | def positional_embedding(len_t,d_model): 24 | ''' 25 | positional embedding along the temporal axis (the spatial axis is ignored) 26 | Args: 27 | len_t: the length of temporal 28 | d_model: the last dimension 29 | 30 | Returns: a tensor (1,t,1,d) 31 | 32 | ''' 33 | PE = np.array([[pos / np.power(10000, (i-i%2)/d_model) for i in range(d_model)] for pos in range(len_t)]) 34 | PE[:, 0::2] = np.sin(PE[:, 0::2]) 35 | PE[:, 1::2] = np.cos(PE[:, 1::2]) 36 | PE = PE* d_model ** 0.5 37 | PE = PE[np.newaxis, :, np.newaxis, :] 38 | outputs = tf.cast(PE, dtype=tf.float32) 39 | return outputs 40 | 41 | def future_mask(att): 42 | ''' 43 | Future masking in the decoder's temporal attention to prevent future information leakage 44 | Args: 45 | att:(bnh,t_q,t_k) 46 | 47 | Returns:(bnh,t_q,t_k) after future_mask only att value or -inf 48 | 49 | ''' 50 | padding_num = -2**32+1 51 | z,t_q,t_k = get_shape(att) 52 | diag_ = tf.ones_like(att[0,:,:]) 53 | tril = tf.linalg.band_part(diag_,-1,0) 54 | future_masks = tf.tile(tf.expand_dims(tril,0),[z,1,1]) 55 | paddings = tf.ones_like(future_masks)*padding_num 56 | return tf.where(tf.equal(future_masks,0),paddings,att) 57 | 58 | def spatial_mask(tm): 59 | ''' 60 | Mask out the positions where the transfer matrix is 0 61 | Args: 62 | tm:(n,n) 63 | 64 | Returns:(1,n,n) only 0 or -inf 65 | 66 | ''' 67 | padding_num = -2 ** 32 + 1 68 | tm_mask = tf.zeros_like(tm) 69 | padding = tf.ones_like(tm_mask)*padding_num 70 | tm_mask = tf.where(tf.equal(tm,0),padding,tm_mask) 71 | return tf.expand_dims(tm_mask,0) 72 | 73 | def meta_guide(inp,meta_ste): 74 | ''' 75 | STE Guided learning to address spatio-temporal heterogeneity 76 | Args: 77 | inp: b,t,n,d 78 | meta_ste:b,t,n,d,d 79 | 80 | Returns: (b,t,n,d) 81 | 82 | ''' 83 | inp_ = tf.expand_dims(inp,-1)#->(b,t,n,d,1) 84 | return tf.squeeze(tf.einsum('btnde,btnei->btndi',meta_ste,inp_),axis=-1)# per-position matrix-vector product (d,d)x(d,1) 85 | 86 | def scale_dot_product_attention(q,k,v,tm,if_future,if_spatial): 87 | ''' 88 | SDPA for multi head attention 89 | Args: 90 | q: (-1,q,d/h) 91 | k: (-1,k,d/h) 92 | v: (-1,k,d/h) 93 | tm: (n,n) 94 | if_future: boolean 95 | if_spatial: boolean 96 | 97 | Returns: (-1,q,d/h),(bnh,tq,tk) 98 | 99 | ''' 100 | mat_qk = tf.matmul(q,tf.transpose(k,[0,2,1])) 101 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 102 | scaled_attention_logits = mat_qk / tf.math.sqrt(dk)#-->(bnh,tq,tk) 103 | if if_future and tm is None: 104 | att = future_mask(scaled_attention_logits) 105 | att = tf.nn.softmax(att) 106 | elif if_spatial: 107 | tm_ = tf.expand_dims(tm,0) 108 | tm_mask = spatial_mask(tm) 109 | att = tf.add(scaled_attention_logits,tm_mask) 110 | att = tf.nn.softmax(att)*tm_ 111 | else: 112 | att = scaled_attention_logits#-->(-1,q,k) 113 | att = tf.nn.softmax(att) 114 | # -->(-1,q,d/h) 115 | return tf.matmul(att,v),att 116 | -------------------------------------------------------------------------------- /pc/fig10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YanJieWen/STGMT-Tensorflow-implementation/586e0a7a788a774892bf783accc86579e7fb8419/pc/fig10.jpg
-------------------------------------------------------------------------------- /pc/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YanJieWen/STGMT-Tensorflow-implementation/586e0a7a788a774892bf783accc86579e7fb8419/pc/framework.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/9 11:22 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : test 6 | # @Project Name :metacode 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | from Hyperparameters import Hyperparameters as hp 11 | from utils_ import * 12 | import pickle 13 | import random 14 | import pandas as pd 15 | 16 | from frameworks import * 17 | 18 | #step1:get test data 19 | data = read_pkl(hp.pkl_08) 20 | test = data[2] 21 | tms = toarray(data[3]) 22 | node2vec = data[4] 23 | scalar = data[5] 24 | test_dataset = tf.data.Dataset.from_tensor_slices((test[:,:12,:,:],test[:,-12:,:,:])) 25 | test_dataset = test_dataset.shuffle(buffer_size=1000).batch(hp.batch_size) 26 | test_dataset = test_dataset.prefetch(tf.data.experimental.AUTOTUNE) 27 | #load model structure 28 | stgmt = STGMT(d_model=hp.num_units,num_heads=hp.num_heads) 29 | #load weights from ckpt format 30 | checkpoint = tf.train.Checkpoint(stgmt=stgmt) 31 | checkpoint.restore(tf.train.latest_checkpoint(hp.ckpt_path)).expect_partial() 32 | print('The latest weight has been restored!') 33 | def evaluate(enc_inp, n2v, tms,model): 34 | b, t, n, d = get_shape(enc_inp) 35 | dec_inp = tf.ones((b, hp.output_len, n, d)) 36 | for step in range(dec_inp.shape[1]): 37 | outs, _, _ = model(enc_inp, dec_inp, n2v, tms, if_te=True, if_se=True, if_ste=True, 38 | if_meta=True, training=False) 39 | outs_ = outs.numpy() 40 | dec_inp_ = dec_inp.numpy() 41 | dec_inp_[:, step, :, :] = outs_[:, step, :, :] 42 | dec_inp = tf.convert_to_tensor(dec_inp_, tf.float32) 43 | return dec_inp # ->(b,t,n,d) 44 | #eval 45 | mae = [] 46 | rmse = [] 47 | mape = [] 48 | gts = [] 49 | preds = [] 50 | # errors = [] 51 | count=0 52 | for batch,(enc_inp,gt) in enumerate(test_dataset): 53 | b,t,n,d = get_shape(gt) 54 | pred_ = evaluate(enc_inp,node2vec,tms,stgmt) 55 | pred = scalar.inverse_transform(tf.reshape(pred_,[-1,1]))#btn,d 56 | gt = scalar.inverse_transform(tf.reshape(gt,[-1,1])) 57 | gt0 = np.reshape(np.transpose(np.reshape(gt,(-1,hp.output_len,n,d)),[0,2,1,3]),(-1,hp.output_len)) 58 | pred0 = np.reshape(np.transpose(np.reshape(pred,(-1,hp.output_len,n,d)),[0,2,1,3]),(-1,hp.output_len)) 59 | preds.append(pred0) 60 | gts.append(gt0) 61 | mae.append(cal_mae(gt0,pred0)[np.newaxis,:]) 62 | rmse.append(cal_rmse(gt0,pred0)[np.newaxis,:]) 63 | mape.append(toarray(cal_mape(gt0,pred0))[np.newaxis,:]) 64 | 65 | mae = np.concatenate(mae,axis=0) 66 | rmse = np.concatenate(rmse,axis=0) 67 | mape = np.concatenate(mape,axis=0) 68 | gts = np.concatenate(gts,axis=0) 69 | preds = np.concatenate(preds,axis=0) 70 | print(np.mean(mae,axis=0)) 71 | print('*'*30) 72 | print(np.mean(rmse,axis=0)) 73 | print('*'*30) 74 | print(np.mean(mape,axis=0)) 75 | print('*'*30) 76 | errors = [np.mean(mae,axis=0),np.mean(rmse,axis=0),np.mean(mape,axis=0)] 77 | df = pd.DataFrame(errors) 78 | df.to_csv('./gap.csv') 79 | writer = pd.ExcelWriter('ana.xlsx') 80 | df_gt = pd.DataFrame(gts[:1000,:]) 81 | df_pred = pd.DataFrame(preds[:1000,:]) 82 | df_gt.to_excel(writer,'sheet1')
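# note: ExcelWriter.save() below is deprecated on newer pandas (>=1.5); writer.close() may be needed instead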
83 | df_pred.to_excel(writer,'sheet2') 84 | writer.save() 85 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/9 11:22 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : train 6 | # @Project Name :metacode 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | from Hyperparameters import Hyperparameters as hp 11 | from utils_ import * 12 | from frameworks import * 13 | from operations import * 14 | import time 15 | import random 16 | 17 | #step1:Prepare dataset pipeline 18 | data = read_pkl(hp.pkl_08) 19 | train= data[0] 20 | val = data[1] 21 | test = data[2] 22 | tms = toarray(data[3]) 23 | node2vec = data[4] 24 | scalar = data[5] 25 | train_dataset = tf.data.Dataset.from_tensor_slices((train[:,:hp.input_len,:,:],train[:,-hp.input_len:,:,:])) 26 | val_dataset = tf.data.Dataset.from_tensor_slices((val[:,:hp.input_len,:,:],val[:,-hp.input_len:,:,:])) 27 | train_dataset = train_dataset.cache()#for train dataset 28 | train_dataset = train_dataset.shuffle(buffer_size=1000).batch(hp.batch_size) 29 | train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE) 30 | val_dataset = val_dataset.shuffle(buffer_size=1000).batch(hp.batch_size) 31 | val_dataset = val_dataset.prefetch(tf.data.experimental.AUTOTUNE) 32 | #step2:define loss object 33 | loss_object = tf.keras.losses.MeanAbsoluteError(name='mean_absolute_error') 34 | def loss_fn(pred,gt): 35 | return tf.reduce_mean(loss_object(gt,pred)) 36 | #step3:define metric object 37 | # def rsqure(gt,pred): 38 | # fenzi = tf.reduce_sum(tf.math.square(pred-gt)) 39 | # fenmu = tf.reduce_sum(tf.math.square(gt-tf.reduce_mean(gt))) 40 | # return tf.cast((1-fenzi/fenmu),tf.float32) 41 | # class R_square(tf.keras.metrics.Metric): 42 | # def __init__(self): 43 | # super().__init__() 44 | # self.total = self.add_weight(name='total', dtype=tf.int32, initializer=tf.zeros_initializer()) 45 | # self.count = self.add_weight(name='count', dtype=tf.int32, initializer=tf.zeros_initializer()) 46 | # def update_state(self,gt,pred): 47 | # values = rsqure(gt,pred) 48 | # self.total.assign_add(tf.reduce_sum(values)) 49 | # self.count.assign_add(tf.cast(tf.size(gt)),tf.float32) 50 | # def result(self): 51 | # return self.total/self.count 52 | train_loss = tf.keras.metrics.MeanSquaredError(name='train_loss') 53 | #step4:Define optimizer 54 | learning_rate = CustomSchedule(hp.num_units) 55 | # optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 56 | # epsilon=1e-9,clipvalue=1.0) 57 | optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 58 | epsilon=1e-9) 59 | #step5:Define training and checkpoint 60 | stgmt = STGMT(d_model=hp.num_units,num_heads=hp.num_heads) 61 | checkpoint_path = hp.ckpt_path 62 | ckpt = tf.train.Checkpoint(stgmt=stgmt, 63 | optimizer=optimizer) 64 | ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5) 65 | # If a checkpoint exists, restore the latest one. 66 | if ckpt_manager.latest_checkpoint: 67 | ckpt.restore(ckpt_manager.latest_checkpoint) 68 | print ('Latest checkpoint restored!!')#if future training 69 | #step6:Define training step 70 | # train_step_signature = [ 71 | # tf.TensorSpec(shape=(None, None,None,None), dtype=tf.float32), 72 | # tf.TensorSpec(shape=(None, None,None,None), dtype=tf.float32), 73 | # ] 74 | @tf.function 75 | def train_step(enc_inp,dec_inp,n2v, tms): 76 | with tf.GradientTape() as
tape: 77 | final_output, enc_atts, dec_atts = stgmt(enc_inp,dec_inp,n2v,tms,if_te=True,if_se=True,if_ste=True,if_meta=True,training=True) 78 | loss = loss_fn(final_output,dec_inp) 79 | gradients = tape.gradient(loss, stgmt.trainable_variables) 80 | optimizer.apply_gradients(zip(gradients, stgmt.trainable_variables)) 81 | train_loss(dec_inp,final_output) 82 | # train_accuarcy() 83 | #eval function 84 | def evaluate(enc_inp,n2v,tms): 85 | b,t,n,d = get_shape(enc_inp) 86 | dec_inp = tf.zeros((b,hp.output_len,n,d)) 87 | for step in range(dec_inp.shape[1]): 88 | outs, _, _ = stgmt(enc_inp, dec_inp, n2v, tms, if_te=True, if_se=True, if_ste=True, 89 | if_meta=True, training=False) 90 | outs_ = outs.numpy() 91 | dec_inp_ = dec_inp.numpy() 92 | dec_inp_[:, step, :, :] = outs_[:, step, :, :] 93 | dec_inp = tf.convert_to_tensor(dec_inp_,tf.float32) 94 | # dec_inp_list = tf.unstack(dec_inp) 95 | # outs_list = tf.unstack(outs) 96 | # dec_inp_list[:,step,:,:]=outs_list[:,step,:,:] 97 | # dec_inp = tf.stack(dec_inp_list) 98 | return dec_inp#->(b,t,n,d) 99 | 100 | 101 | #train processing 102 | init_val = hp.init_val 103 | for epoch in range(hp.num_epochs): 104 | start = time.time() 105 | train_loss.reset_states() 106 | # train_accuarcy.update_state(gt,pred) 107 | for (batch,(enc_inp,dec_inp)) in enumerate(train_dataset): 108 | train_step(enc_inp,dec_inp,node2vec, tms) 109 | if batch%50==0: 110 | print('Epoch {} /Batch {} Loss {:.4f}'.format( 111 | epoch + 1, batch, train_loss.result())) 112 | if (epoch+1)%5==0 and batch%50==0: 113 | if hp.if_val:#if use validation data 114 | count = 0 115 | errors = [] 116 | while count<10: 117 | enc_inp,gt = next(iter(val_dataset)) 118 | pred_ = evaluate(enc_inp,node2vec,tms) 119 | pred = scalar.inverse_transform(tf.reshape(pred_,[-1,hp.out_units])) 120 | gt = scalar.inverse_transform(tf.reshape(gt,[-1,hp.out_units])) 121 | mae = tf.reduce_mean(tf.abs(pred-gt)) 122 | errors.append(mae) 123 | count+=1 124 | mae = tf.reduce_mean(tf.convert_to_tensor(errors,tf.float32)) 125 | if mae.numpy()<init_val: 126 | ckpt_save_path = ckpt_manager.save() 127 | print('Saving checkpoint for epoch {} at {}->MAE is {:.4f}'.format(epoch + 1, 128 | ckpt_save_path,mae.numpy())) 129 | init_val = mae.numpy() 130 | 131 | else: 132 | ckpt_save_path = ckpt_manager.save() 133 | print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, 134 | ckpt_save_path)) 135 | print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start)) 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /utils_.py: -------------------------------------------------------------------------------- 1 | # @Time : 2022/6/9 11:22 2 | # @Author : Yanjie WEN 3 | # @Institution : CSU & BUCEA 4 | # @IDE : Pycharm 5 | # @FileName : utils_ 6 | # @Project Name :metacode 7 | from copyreg import pickle 8 | from Hyperparameters import Hyperparameters as hp 9 | import numpy as np 10 | import pandas as pd 11 | import matplotlib.pyplot as plt 12 | import networkx as nx 13 | from sklearn.preprocessing import MinMaxScaler 14 | import pickle 15 | 16 | #building a graph 17 | def build_graph(path_file): 18 | """ 19 | 20 | :param path_file: the adj matrix file path 21 | :return: a directed graph carrying the edge information 22 | """ 23 | df = pd.read_csv(path_file) 24 | graph = nx.DiGraph()#a directed graph 25 | edges = df.values 26 | graph.add_weighted_edges_from(edges) 27 | #visualization 28 | pos=nx.random_layout(graph) # gen the location of nodes 29 | plt.rcParams['figure.figsize']= (6, 4) # picture size 30 | nx.draw(graph,pos,with_labels=True, node_color='white',
edge_color='red', node_size=15, alpha=0.5 ) 31 | plt.title('Self_Define Net',fontsize=18) 32 | plt.show() 33 | return graph 34 | 35 | def toarray(data):#get array 36 | if isinstance(data,list): 37 | return np.array(data) 38 | else: 39 | return data 40 | 41 | def minmaxsca(datas): 42 | """ 43 | 44 | :param datas: (points,sensors,flow) 45 | :return: (points,sensors,flow), the fitted scaler 46 | """ 47 | n,s,f = datas.shape 48 | sca = MinMaxScaler(feature_range=(-1,1)) 49 | datas = sca.fit_transform(datas.reshape(-1,f)).reshape(-1,s,f) 50 | return datas,sca 51 | def inverse_minmaxsca(datas,sca): 52 | """ 53 | 54 | :param datas: (points,sensors,flow) after minmax 55 | :param sca: the fitted scaler 56 | :return: the data inverse-transformed back to its original values 57 | """ 58 | n, s, f = datas.shape 59 | data = sca.inverse_transform(datas.reshape(-1,f)).reshape(-1,s,f) 60 | return data 61 | 62 | def sim_distance(adj_matrix,df): 63 | """ 64 | 65 | :param adj_matrix: (N,N) 66 | :param df: distance dataframe 67 | :return: (N,N) 68 | """ 69 | W2 = adj_matrix*adj_matrix 70 | dis_sim = np.exp(-W2 / (np.std(df['cost']) * np.std(df['cost']))) 71 | dis_sim[np.where(dis_sim <= hp.eplision)] = 1#first set values <= eplision to 1, then reset those 1s to 0 (inverse selection) 72 | dis_sim[np.where(dis_sim == 1)] = 0 73 | return dis_sim 74 | 75 | def time_sim_matrix(datas): 76 | """ 77 | 78 | :param datas: original data 79 | :return: (N,N) 80 | """ 81 | num_sensors = datas.shape[1] 82 | sim_init = np.zeros((num_sensors, num_sensors)) 83 | for i in range(num_sensors): 84 | for j in range(num_sensors): 85 | if i!=j: 86 | set_1 = datas[:24 * 7, i, :]#one week of data; more data is not needed 87 | set_2 = datas[:24 * 7, j, :] 88 | # time_sim = np.sum(set_1 * set_2) / (np.linalg.norm(set_1) * np.linalg.norm(set_2)) 89 | time_sim = np.exp(-(np.linalg.norm(set_1-set_2)/min(np.linalg.norm(set_1), 90 | np.linalg.norm(set_2)))) 91 | sim_init[i, j] = time_sim 92 | sim_init[np.where(sim_init <= hp.time_eplision)] = 0 93 | diag = np.identity(sim_init.shape[0])#add self-loop 94 | sim_init+=diag 95 | x_norm = sim_init / np.sum(sim_init, axis=1,keepdims=True) 96 | return x_norm 97 | 98 | def norm_adjmatrix(cal_matrix): 99 | """ 100 | row norm for distsim_matrix 101 | :param cal_matrix: (N,N),after calculating the matrix 102 | :return: (N,N) 103 | """ 104 | for i in range(cal_matrix.shape[0]): 105 | for j in range(cal_matrix.shape[1]): 106 | if i==j: 107 | cal_matrix[i,j]=0 108 | diag = np.identity(cal_matrix.shape[0]) 109 | cal_matrix+=diag#add self-loop 110 | x_norm = cal_matrix/np.sum(cal_matrix,axis=1,keepdims=True) 111 | return x_norm 112 | 113 | def split_train_val_test(data): 114 | """ 115 | shuffle the data sets 116 | :param data: (s,t,n,d) 117 | :return: train,val,test 118 | """ 119 | np.random.seed(42) 120 | shuffle_id = np.random.permutation(len(data)) 121 | train_datas_id = shuffle_id[:int(len(data)*0.6)] 122 | val_datas_id = shuffle_id[int(len(data)*0.6):(int(len(data)*0.2)+int(len(data)*0.6))] 123 | test_datas_id = shuffle_id[int(len(data)*0.2)+int(len(data)*0.6):] 124 | return data[train_datas_id],data[val_datas_id],data[test_datas_id] 125 | 126 | def read_pkl(data_path): 127 | with open(data_path,'rb') as f: 128 | return pickle.load(f) 129 | 130 | #cal metrics->(bn,t)->(t,) 131 | def cal_mae(gt,pred): 132 | mae = np.abs(gt-pred) 133 | return np.mean(mae,axis=0) 134 | def cal_rmse(gt,pred): 135 | rmse = np.sqrt(np.mean(np.square(gt-pred),axis=0)) 136 | return rmse 137 | def cal_mape(gt,pred): 138 | multi_time_step = [] 139 | for t in range(gt.shape[1]):
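# keep only entries where the ground truth >= 10, so MAPE is not inflated by near-zero flows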
140 | gt_t = gt[:,t] 141 | pred_t = pred[:,t] 142 | gt_t_se = gt_t[np.where(gt_t>=10)] 143 | pred_t_se = pred_t[np.where(gt_t>=10)] 144 | mape = np.mean(np.abs(gt_t_se-pred_t_se)/np.abs(gt_t_se)) 145 | multi_time_step.append(mape) 146 | return multi_time_step 147 | def cal_wmape(gt,pred): 148 | sum_gt = np.sum(gt,axis=0) 149 | abs_diff = np.sum(np.abs(pred-gt),axis=0) 150 | return abs_diff/sum_gt 151 | --------------------------------------------------------------------------------