├── .gitignore ├── .pylintrc ├── LICENSE ├── README.md ├── basicts ├── __init__.py ├── archs │ ├── __init__.py │ └── arch_zoo │ │ ├── agcrn_arch │ │ ├── __init__.py │ │ ├── agcn.py │ │ ├── agcrn_arch.py │ │ └── agcrn_cell.py │ │ ├── autoformer_arch │ │ ├── __init__.py │ │ ├── auto_correlation.py │ │ ├── autoformer_arch.py │ │ ├── embed.py │ │ └── enc_dec.py │ │ ├── d2stgnn_arch │ │ ├── __init__.py │ │ ├── d2stgnn_arch.py │ │ ├── decouple │ │ │ ├── estimation_gate.py │ │ │ └── residual_decomp.py │ │ ├── difusion_block │ │ │ ├── __init__.py │ │ │ ├── dif_block.py │ │ │ ├── dif_model.py │ │ │ └── forecast.py │ │ ├── dynamic_graph_conv │ │ │ ├── dy_graph_conv.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── distance.py │ │ │ │ ├── mask.py │ │ │ │ └── normalizer.py │ │ └── inherent_block │ │ │ ├── __init__.py │ │ │ ├── forecast.py │ │ │ ├── inh_block.py │ │ │ └── inh_model.py │ │ ├── dcrnn_arch │ │ ├── __init__.py │ │ ├── dcrnn_arch.py │ │ └── dcrnn_cell.py │ │ ├── dgcrn_arch │ │ ├── __init__.py │ │ ├── dgcrn_arch.py │ │ └── dgcrn_layer.py │ │ ├── fedformer_arch │ │ ├── __init__.py │ │ ├── auto_correlation.py │ │ ├── embed.py │ │ ├── fedformer_arch.py │ │ ├── fedformer_enc_dec.py │ │ ├── fourier_correlation.py │ │ ├── masking.py │ │ ├── multi_wavelet_correlation.py │ │ ├── self_attention_family.py │ │ └── utils.py │ │ ├── gts_arch │ │ ├── __init__.py │ │ ├── gts_arch.py │ │ └── gts_cell.py │ │ ├── gwnet_arch │ │ ├── __init__.py │ │ └── gwnet_arch.py │ │ ├── hi_arch │ │ ├── __init__.py │ │ └── hi_arch.py │ │ ├── informer_arch │ │ ├── __init__.py │ │ ├── attn.py │ │ ├── decoder.py │ │ ├── embed.py │ │ ├── encoder.py │ │ ├── informer_arch.py │ │ └── masking.py │ │ ├── linear_arch │ │ ├── __init__.py │ │ ├── dlinear.py │ │ ├── linear.py │ │ └── nlinear.py │ │ ├── megacrn │ │ ├── __init__.py │ │ └── megacrn_arch.py │ │ ├── mtgnn_arch │ │ ├── __init__.py │ │ ├── mtgnn_arch.py │ │ └── mtgnn_layers.py │ │ ├── pyraformer_arch │ │ ├── __init__.py │ │ ├── embed.py │ │ ├── 
hierarchical_mm_tvm.py │ │ ├── layers.py │ │ ├── modules.py │ │ ├── pam_tvm.py │ │ ├── pyraformer_arch.py │ │ └── sub_layers.py │ │ ├── stemgnn_arch │ │ ├── __init__.py │ │ └── stemgnn_arch.py │ │ ├── stgcn_arch │ │ ├── __init__.py │ │ ├── stgcn_arch.py │ │ └── stgcn_layers.py │ │ ├── stid_arch │ │ ├── __init__.py │ │ ├── mlp.py │ │ └── stid_arch.py │ │ ├── stnorm_arch │ │ ├── __init__.py │ │ └── stnorm_arch.py │ │ └── utils │ │ ├── __init__.py │ │ └── xformer.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── registry.py │ └── transform.py ├── launcher.py ├── losses │ ├── __init__.py │ └── losses.py ├── metrics │ ├── __init__.py │ ├── mae.py │ ├── mape.py │ └── rmse.py ├── runners │ ├── __init__.py │ ├── base_runner.py │ ├── base_tsf_runner.py │ └── runner_zoo │ │ ├── agcrn_runner.py │ │ ├── autoformer_runner.py │ │ ├── d2stgnn_runner.py │ │ ├── dcrnn_runner.py │ │ ├── dgcrn_runner.py │ │ ├── fedformer_runner.py │ │ ├── gts_runner.py │ │ ├── gwnet_runner.py │ │ ├── hi_runner.py │ │ ├── informer_runner.py │ │ ├── linear_runner.py │ │ ├── megecrn_runner.py │ │ ├── mtgnn_runner.py │ │ ├── pyraformer_runner.py │ │ ├── simple_tsf_runner.py │ │ ├── stemgnn_runner.py │ │ ├── stgcn_runner.py │ │ ├── stid_runner.py │ │ └── stnorm_runner.py └── utils │ ├── __init__.py │ ├── adjacent_matrix_norm.py │ ├── misc.py │ └── serialization.py ├── datasets └── README.md ├── figure ├── Inspecting.jpg ├── MainResults.png └── STEP.png ├── requirements.txt ├── scripts └── data_preparation │ ├── METR-LA │ └── generate_training_data.py │ ├── PEMS-BAY │ └── generate_training_data.py │ ├── PEMS03 │ ├── generate_adj_mx.py │ └── generate_training_data.py │ ├── PEMS04 │ ├── generate_adj_mx.py │ └── generate_training_data.py │ ├── PEMS07 │ ├── generate_adj_mx.py │ └── generate_training_data.py │ ├── PEMS08 │ ├── generate_adj_mx.py │ └── generate_training_data.py │ └── all.sh ├── step ├── STEP_METR-LA.py ├── STEP_PEMS-BAY.py ├── STEP_PEMS03.py ├── STEP_PEMS04.py ├── STEP_PEMS07.py ├── 
STEP_PEMS08.py ├── TSFormer_METR-LA.py ├── TSFormer_PEMS-BAY.py ├── TSFormer_PEMS03.py ├── TSFormer_PEMS04.py ├── TSFormer_PEMS07.py ├── TSFormer_PEMS08.py ├── run.py ├── step_arch │ ├── __init__.py │ ├── discrete_graph_learning.py │ ├── graphwavenet │ │ ├── __init__.py │ │ └── model.py │ ├── similarity.py │ ├── step.py │ └── tsformer │ │ ├── __init__.py │ │ ├── mask.py │ │ ├── patch.py │ │ ├── positional_encoding.py │ │ ├── transformer_layers.py │ │ └── tsformer.py ├── step_data │ ├── __init__.py │ ├── forecasting_dataset.py │ └── pretraining_dataset.py ├── step_loss │ ├── __init__.py │ └── step_loss.py └── step_runner │ ├── __init__.py │ ├── step_runner.py │ └── tsformer_runner.py ├── test └── test_inference.py ├── training_logs ├── STEP_METR-LA.log ├── STEP_PEMS-BAY.log ├── STEP_PEMS04.log ├── STEP_PEMS08.log ├── TSFormer_METR-LA.log ├── TSFormer_PEMS-BAY.log ├── TSFormer_PEMS03.log ├── TSFormer_PEMS04.log ├── TSFormer_PEMS07.log └── TSFormer_PEMS08.log └── tsformer_ckpt ├── TSFormer_METR-LA.pt ├── TSFormer_PEMS-BAY.pt ├── TSFormer_PEMS03.pt ├── TSFormer_PEMS04.pt ├── TSFormer_PEMS07.pt └── TSFormer_PEMS08.pt /.gitignore: -------------------------------------------------------------------------------- 1 | # dir 2 | __pycache__/ 3 | .vscode/ 4 | checkpoints/ 5 | datasets/raw_data 6 | todo.md 7 | gpu_task.py 8 | cmd.sh 9 | 10 | # file 11 | *.npz 12 | *.npy 13 | *.csv 14 | *.pkl 15 | *.h5 16 | *.pt 17 | core* 18 | *.p 19 | *.pickle 20 | *.pyc 21 | *.txt 22 | 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | # *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject 
date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | cover/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | .pybuilder/ 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | # For a library or package, you might want to ignore these files since the code is 107 | # intended to run in multiple environments; otherwise, check them in: 108 | # .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # poetry 118 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 119 | # This is especially recommended for binary packages to ensure reproducibility, and is more 120 | # commonly ignored for libraries. 121 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 122 | #poetry.lock 123 | 124 | # pdm 125 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
126 | #pdm.lock 127 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 128 | # in version control. 129 | # https://pdm.fming.dev/#use-with-ide 130 | .pdm.toml 131 | 132 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 133 | __pypackages__/ 134 | 135 | # Celery stuff 136 | celerybeat-schedule 137 | celerybeat.pid 138 | 139 | # SageMath parsed files 140 | *.sage.py 141 | 142 | # Environments 143 | .env 144 | .venv 145 | env/ 146 | venv/ 147 | ENV/ 148 | env.bak/ 149 | venv.bak/ 150 | 151 | # Spyder project settings 152 | .spyderproject 153 | .spyproject 154 | 155 | # Rope project settings 156 | .ropeproject 157 | 158 | # mkdocs documentation 159 | /site 160 | 161 | # mypy 162 | .mypy_cache/ 163 | .dmypy.json 164 | dmypy.json 165 | 166 | # Pyre type checker 167 | .pyre/ 168 | 169 | # pytype static type analyzer 170 | .pytype/ 171 | 172 | # Cython debug symbols 173 | cython_debug/ 174 | 175 | # PyCharm 176 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 177 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 178 | # and can be added to the global gitignore or merged into this file. For a more nuclear 179 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
-------------------------------------------------------------------------------- /basicts/__init__.py: -------------------------------------------------------------------------------- 1 | from .launcher import launch_training 2 | 3 | __version__ = "0.1.6" 4 | 5 | __all__ = ["__version__", "launch_training"] 6 | -------------------------------------------------------------------------------- /basicts/archs/__init__.py: -------------------------------------------------------------------------------- 1 | from .arch_zoo.stid_arch import STID 2 | from .arch_zoo.gwnet_arch import GraphWaveNet 3 | from .arch_zoo.dcrnn_arch import DCRNN 4 | from .arch_zoo.d2stgnn_arch import D2STGNN 5 | from .arch_zoo.stgcn_arch import STGCN 6 | from .arch_zoo.mtgnn_arch import MTGNN 7 | from .arch_zoo.stnorm_arch import STNorm 8 | from .arch_zoo.agcrn_arch import AGCRN 9 | from .arch_zoo.stemgnn_arch import StemGNN 10 | from .arch_zoo.gts_arch import GTS 11 | from .arch_zoo.dgcrn_arch import DGCRN 12 | from .arch_zoo.linear_arch import Linear, DLinear, NLinear 13 | from .arch_zoo.autoformer_arch import Autoformer 14 | from .arch_zoo.hi_arch import HINetwork 15 | from .arch_zoo.fedformer_arch import FEDformer 16 | from .arch_zoo.informer_arch import Informer, InformerStack 17 | from .arch_zoo.pyraformer_arch import Pyraformer 18 | from .arch_zoo.megacrn import MegaCRN 19 | 20 | __all__ = ["STID", "GraphWaveNet", "DCRNN", 21 | "D2STGNN", "STGCN", "MTGNN", 22 | "STNorm", "AGCRN", "StemGNN", 23 | "GTS", "DGCRN", "Linear", 24 | "DLinear", "NLinear", "Autoformer", 25 | "HINetwork", "FEDformer", "Informer", 26 | "InformerStack", "Pyraformer", 27 | "MegaCRN"] 28 | -------------------------------------------------------------------------------- /basicts/archs/arch_zoo/agcrn_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .agcrn_arch import AGCRN 2 | 3 | __all__ = ["AGCRN"] 4 | 
# NOTE: this span of the listing concatenates three files of the agcrn_arch
# package. Package-relative imports between them are kept as comments because
# the classes they name are defined in this same span.

# ---- /basicts/archs/arch_zoo/agcrn_arch/agcn.py ----
import torch
import torch.nn.functional as F
import torch.nn as nn


class AVWGCN(nn.Module):
    """Graph convolution whose weights are generated per node from node embeddings.

    No adjacency matrix is passed in: ``forward`` derives it from the node
    embeddings ``E`` as ``softmax(relu(E @ E^T))`` and expands it into
    ``cheb_k`` Chebyshev terms.

    Args:
        dim_in: number of input features per node.
        dim_out: number of output features per node.
        cheb_k: number of Chebyshev terms (the identity term included).
        embed_dim: dimension of the node embeddings.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(AVWGCN, self).__init__()
        self.cheb_k = cheb_k
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor.
        # Both leave the memory uninitialised, so the owner must initialise
        # these parameters (AGCRN.init_param does).
        self.weights_pool = nn.Parameter(
            torch.empty(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.empty(embed_dim, dim_out))

    def forward(self, x, node_embeddings):
        """Apply the adaptive graph convolution.

        Args:
            x: input shaped [B, N, C].
            node_embeddings: node embeddings shaped [N, D].

        Returns:
            Tensor shaped [B, N, dim_out].
        """
        node_num = node_embeddings.shape[0]
        # learned adjacency, shaped [N, N]
        supports = F.softmax(
            F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)
        # Build the identity directly on the right device/dtype instead of
        # creating it on CPU and copying it over on every call.
        identity = torch.eye(
            node_num, device=supports.device, dtype=supports.dtype)
        support_set = [identity, supports]
        # Chebyshev recurrence: T_k = 2 * A * T_{k-1} - T_{k-2}
        for _ in range(2, self.cheb_k):
            support_set.append(torch.matmul(
                2 * supports, support_set[-1]) - support_set[-2])
        # Keep exactly cheb_k terms so that cheb_k == 1 (identity only)
        # matches the first-order size of weights_pool.
        supports = torch.stack(support_set[:self.cheb_k], dim=0)
        # N, cheb_k, dim_in, dim_out
        weights = torch.einsum(
            'nd,dkio->nkio', node_embeddings, self.weights_pool)
        bias = torch.matmul(node_embeddings, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, x)  # B, cheb_k, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, cheb_k, dim_in
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out
        return x_gconv


# ---- /basicts/archs/arch_zoo/agcrn_arch/agcrn_arch.py ----
# import torch
# import torch.nn as nn
# from .agcrn_cell import AGCRNCell


class AVWDCRNN(nn.Module):
    """Stack of ``num_layers`` AGCRN cells unrolled over the time axis.

    Args:
        node_num: number of nodes N.
        dim_in: input feature dimension of the first layer.
        dim_out: hidden dimension of every layer.
        cheb_k: number of Chebyshev terms used by each cell.
        embed_dim: node-embedding dimension.
        num_layers: number of stacked recurrent layers (>= 1).
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super(AVWDCRNN, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.dcrnn_cells = nn.ModuleList()
        # the first layer consumes the raw input, later layers the hidden states
        self.dcrnn_cells.append(
            AGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
        for _ in range(1, num_layers):
            self.dcrnn_cells.append(
                AGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim))

    def forward(self, x, init_state, node_embeddings):
        """Run all layers over the whole sequence.

        Args:
            x: input shaped (B, T, N, D).
            init_state: initial states shaped (num_layers, B, N, hidden_dim).
            node_embeddings: node embeddings shaped [N, embed_dim].

        Returns:
            Tuple ``(outputs, last_states)``: ``outputs`` is the last layer's
            hidden state at every step, shaped (B, T, N, hidden_dim);
            ``last_states`` is a list holding the final (B, N, hidden_dim)
            state of each layer.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        seq_length = x.shape[1]
        current_inputs = x
        output_hidden = []
        for layer_idx, cell in enumerate(self.dcrnn_cells):
            state = init_state[layer_idx]
            inner_states = []
            for t in range(seq_length):
                state = cell(current_inputs[:, t, :, :], state, node_embeddings)
                inner_states.append(state)
            output_hidden.append(state)
            # this layer's per-step states feed the next layer
            current_inputs = torch.stack(inner_states, dim=1)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size, device=None):
        """Return zero initial states shaped (num_layers, B, N, hidden_dim).

        Args:
            batch_size: batch size B.
            device: optional device to allocate on (default: torch's default).
        """
        init_states = [
            cell.init_hidden_state(batch_size, device=device)
            for cell in self.dcrnn_cells
        ]
        return torch.stack(init_states, dim=0)


class AGCRN(nn.Module):
    """
    Paper: Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting
    Official Code: https://github.com/LeiBAI/AGCRN
    Link: https://arxiv.org/abs/2007.02842

    Args:
        num_nodes: number of nodes N.
        input_dim: input feature dimension.
        rnn_units: hidden dimension of the recurrent encoder.
        output_dim: number of predicted features per node and step.
        horizon: number of predicted steps.
        num_layers: number of recurrent layers.
        default_graph: stored on the module; not read in this file.
        embed_dim: node-embedding dimension.
        cheb_k: number of Chebyshev terms.
    """

    def __init__(self, num_nodes, input_dim, rnn_units, output_dim, horizon, num_layers, default_graph, embed_dim, cheb_k):
        super(AGCRN, self).__init__()
        self.num_node = num_nodes
        self.input_dim = input_dim
        self.hidden_dim = rnn_units
        self.output_dim = output_dim
        self.horizon = horizon
        self.num_layers = num_layers

        self.default_graph = default_graph
        self.node_embeddings = nn.Parameter(torch.randn(
            self.num_node, embed_dim), requires_grad=True)

        self.encoder = AVWDCRNN(num_nodes, input_dim, rnn_units, cheb_k,
                                embed_dim, num_layers)

        # predictor: maps the last hidden state to horizon * output_dim channels
        self.end_conv = nn.Conv2d(
            1, horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

        self.init_param()

    def init_param(self):
        """Initialise every parameter: Xavier-uniform for tensors with more
        than one dimension, uniform on [0, 1) otherwise."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor:
        """Feedforward function of AGCRN.

        Args:
            history_data (torch.Tensor): inputs with shape [B, L, N, C].
            future_data, batch_seen, epoch, train, **kwargs: accepted but not
                used by this model.

        Returns:
            torch.Tensor: outputs with shape [B, horizon, N, output_dim]
        """
        # allocate the initial state next to the input instead of on CPU
        init_state = self.encoder.init_hidden(
            history_data.shape[0], device=history_data.device)
        output, _ = self.encoder(
            history_data, init_state, self.node_embeddings)  # B, T, N, hidden
        output = output[:, -1:, :, :]  # B, 1, N, hidden

        # CNN based predictor
        output = self.end_conv(output)  # B, T*C, N, 1
        output = output.squeeze(-1).reshape(-1, self.horizon,
                                            self.output_dim, self.num_node)
        output = output.permute(0, 1, 3, 2)  # B, T, N, C

        return output


# ---- /basicts/archs/arch_zoo/agcrn_arch/agcrn_cell.py ----
# import torch
# import torch.nn as nn
# from .agcn import AVWGCN


class AGCRNCell(nn.Module):
    """Gated recurrent cell whose gate and candidate transforms are AVWGCN layers.

    Args:
        node_num: number of nodes N.
        dim_in: input feature dimension.
        dim_out: hidden dimension.
        cheb_k: number of Chebyshev terms of the graph convolutions.
        embed_dim: node-embedding dimension.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super(AGCRNCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # one convolution yields both gates (2 * dim_out channels)
        self.gate = AVWGCN(dim_in + self.hidden_dim, 2 * dim_out,
                           cheb_k, embed_dim)
        self.update = AVWGCN(dim_in + self.hidden_dim, dim_out,
                             cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings):
        """Advance the cell by one step.

        Args:
            x: input shaped [B, num_nodes, input_dim].
            state: previous hidden state shaped [B, num_nodes, hidden_dim].
            node_embeddings: node embeddings shaped [num_nodes, embed_dim].

        Returns:
            New hidden state shaped [B, num_nodes, hidden_dim].
        """
        # callers may still hand in a state created on another device
        state = state.to(x.device)
        input_and_state = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, node_embeddings))
        h = r * state + (1 - r) * hc
        return h

    def init_hidden_state(self, batch_size, device=None):
        """Return a zero state shaped [batch_size, node_num, hidden_dim].

        Args:
            batch_size: batch size B.
            device: optional device to allocate on (default: torch's default).
        """
        return torch.zeros(batch_size, self.node_num, self.hidden_dim,
                           device=device)


# /basicts/archs/arch_zoo/autoformer_arch/__init__.py:
# NOTE: this span of the listing concatenates several package files.
# Package-relative imports and per-file ``__all__`` declarations are kept as
# comments: the classes they name are either defined in this same span or live
# in files outside it.

# autoformer_arch/__init__.py (its path header precedes this span)
# from .autoformer_arch import Autoformer
# __all__ = ["Autoformer"]

# ---- /basicts/archs/arch_zoo/d2stgnn_arch/__init__.py ----
# from .d2stgnn_arch import D2STGNN
# __all__ = ["D2STGNN"]

# ---- /basicts/archs/arch_zoo/d2stgnn_arch/decouple/estimation_gate.py ----
import torch
import torch.nn as nn


class EstimationGate(nn.Module):
    r"""
    The spatial gate module.

    Computes a sigmoid gate in (0, 1) from the time embeddings and the two node
    embeddings, and scales the input signal with it.

    Args:
        node_emb_dim: dimension of each node embedding.
        time_emb_dim: dimension of each of the two time embeddings.
        hidden_dim: width of the hidden layer of the gate MLP.
        input_seq_len: accepted for interface compatibility; not used.
    """

    def __init__(self, node_emb_dim, time_emb_dim, hidden_dim, input_seq_len):
        super().__init__()
        self.FC1 = nn.Linear(2 * node_emb_dim + time_emb_dim * 2, hidden_dim)
        self.act = nn.ReLU()
        self.FC2 = nn.Linear(hidden_dim, 1)

    def forward(self, node_embedding1, node_embedding2, T_D, D_W, X):
        """Gate ``X``.

        Args:
            node_embedding1, node_embedding2: node embeddings shaped [N, node_emb_dim].
            T_D, D_W: time embeddings shaped [B, L, N, time_emb_dim].
            X: signal shaped [B, L', N, ...] with L' <= L.

        Returns:
            ``X`` multiplied by the gate of its last ``L'`` time steps.
        """
        batch_size, seq_len = T_D.shape[:2]
        # broadcast the per-node embeddings over batch and time
        emb1 = node_embedding1.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1, -1)
        emb2 = node_embedding2.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1, -1)
        spatial_gate_feat = torch.cat([T_D, D_W, emb1, emb2], dim=-1)
        hidden = self.act(self.FC1(spatial_gate_feat))
        # keep only the gate values aligned with the time steps present in X
        spatial_gate = torch.sigmoid(self.FC2(hidden))[:, -X.shape[1]:, :, :]
        return X * spatial_gate


# ---- /basicts/archs/arch_zoo/d2stgnn_arch/decouple/residual_decomp.py ----
# import torch.nn as nn


class ResidualDecomp(nn.Module):
    r"""
    Residual decomposition.

    Subtracts the rectified backcast from the input and layer-normalises the
    result over the last dimension.

    Args:
        input_shape: sequence whose last entry is the feature dimension.
    """

    def __init__(self, input_shape):
        super().__init__()
        self.ln = nn.LayerNorm(input_shape[-1])
        self.ac = nn.ReLU()

    def forward(self, x, y):
        """Return ``LayerNorm(x - ReLU(y))``."""
        return self.ln(x - self.ac(y))


# ---- /basicts/archs/arch_zoo/d2stgnn_arch/difusion_block/__init__.py ----
# from ..difusion_block.dif_block import DifBlock

# ---- /basicts/archs/arch_zoo/d2stgnn_arch/difusion_block/dif_block.py ----
# import torch.nn as nn
# from ..decouple.residual_decomp import ResidualDecomp
# from .forecast import Forecast
# from .dif_model import STLocalizedConv


class DifBlock(nn.Module):
    """Diffusion block: localized spatial-temporal convolution followed by a
    forecast branch and a backcast (residual) branch.

    Args:
        hidden_dim: feature dimension of the hidden signals.
        fk_dim: output dimension of the forecast branch.
        use_pre, dy_graph, sta_graph: flags forwarded to STLocalizedConv that
            select the predefined / dynamic / static-hidden graphs.
        **model_args: must contain ``'adjs'`` plus the keys read by
            STLocalizedConv and Forecast.
    """

    def __init__(self, hidden_dim, fk_dim=256, use_pre=None, dy_graph=None, sta_graph=None, **model_args):
        super().__init__()
        self.pre_defined_graph = model_args['adjs']
        self.localized_st_conv = STLocalizedConv(
            hidden_dim, pre_defined_graph=self.pre_defined_graph,
            use_pre=use_pre, dy_graph=dy_graph, sta_graph=sta_graph, **model_args)
        # sub and norm
        self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])
        # forecast
        self.forecast_branch = Forecast(
            hidden_dim, fk_dim=fk_dim, **model_args)
        # backcast
        self.backcast_branch = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, X_spa, dynamic_graph, static_graph):
        """Run the block.

        Args:
            X: input signal, [bs, seq, nodes, feat].
            X_spa: signal fed to the convolution, [bs, seq, nodes, feat].
            dynamic_graph: list of dynamic graphs (read only when enabled).
            static_graph: static hidden graphs (read only when enabled).

        Returns:
            Tuple ``(backcast_seq_res, forecast_hidden)``.
        """
        Z = self.localized_st_conv(X_spa, dynamic_graph, static_graph)
        # forecast branch
        forecast_hidden = self.forecast_branch(
            X_spa, Z, self.localized_st_conv, dynamic_graph, static_graph)
        # backcast branch
        backcast_seq = self.backcast_branch(Z)
        # Residual decomposition: the convolution shortens the sequence, so
        # align X with the steps that survive.
        X = X[:, -backcast_seq.shape[1]:, :, :]
        backcast_seq_res = self.residual_decompose(X, backcast_seq)

        return backcast_seq_res, forecast_hidden


# ---- /basicts/archs/arch_zoo/d2stgnn_arch/difusion_block/dif_model.py ----
# import torch
# import torch.nn as nn


class STLocalizedConv(nn.Module):
    """Spatial-temporal localized graph convolution.

    Each output step mixes a window of ``k_t`` consecutive input steps through
    graphs of shape [num_nodes, k_t * num_nodes].

    Args:
        hidden_dim: feature dimension.
        pre_defined_graph: list of predefined [num_nodes, num_nodes] graphs.
        use_pre, dy_graph, sta_graph: truthy flags enabling the predefined,
            dynamic and static-hidden graphs; they are passed through ``int()``.
        **model_args: must contain ``'k_s'``, ``'k_t'`` and ``'dropout'``.
    """

    def __init__(self, hidden_dim, pre_defined_graph=None, use_pre=None, dy_graph=None, sta_graph=None, **model_args):
        super().__init__()
        # gated temporal conv
        self.k_s = model_args['k_s']
        self.k_t = model_args['k_t']
        self.hidden_dim = hidden_dim

        # graph conv
        self.pre_defined_graph = pre_defined_graph
        self.use_predefined_graph = use_pre
        self.use_dynamic_hidden_graph = dy_graph
        self.use_static__hidden_graph = sta_graph

        num_pre = len(self.pre_defined_graph)
        self.support_len = num_pre + int(dy_graph) + int(sta_graph)
        # number of feature groups concatenated in gconv: k_s orders of every
        # enabled graph, plus one for X_0
        self.num_matric = (int(use_pre) * num_pre + num_pre * int(dy_graph)
                           + int(sta_graph)) * self.k_s + 1
        self.dropout = nn.Dropout(model_args['dropout'])
        self.pre_defined_graph = self.get_graph(self.pre_defined_graph)

        self.fc_list_updt = nn.Linear(
            self.k_t * hidden_dim, self.k_t * hidden_dim, bias=False)
        self.gcn_updt = nn.Linear(
            self.hidden_dim * self.num_matric, self.hidden_dim)

        # others
        self.bn = nn.BatchNorm2d(self.hidden_dim)
        self.activation = nn.ReLU()

    def gconv(self, support, X_k, X_0):
        """Aggregate ``X_k`` through every graph in ``support`` and fuse the
        results with ``X_0`` by a linear layer followed by dropout."""
        out = [X_0]
        for graph in support:
            # static or predefined graphs are 2-D; others get an extra axis so
            # they broadcast over the sequence dimension
            if len(graph.shape) != 2:
                graph = graph.unsqueeze(1)
            out.append(torch.matmul(graph, X_k))
        out = torch.cat(out, dim=-1)
        out = self.gcn_updt(out)
        out = self.dropout(out)
        return out

    def get_graph(self, support):
        """Expand static graphs into their 1..k_s order, self-loop-free,
        spatial-temporal localized form.

        Only used for static graphs (static hidden graph and predefined
        graph), not for dynamic graphs.

        Returns:
            List of tensors shaped [num_nodes, k_t * num_nodes].
        """
        graph_ordered = []
        # zero the diagonal of every order
        mask = 1 - torch.eye(support[0].shape[0]).to(support[0].device)
        for graph in support:
            k_1_order = graph  # 1 order
            graph_ordered.append(k_1_order * mask)
            # e.g., order = 3, k=[2, 3]; order = 2, k=[2]
            for _ in range(2, self.k_s + 1):
                k_1_order = torch.matmul(graph, k_1_order)
                graph_ordered.append(k_1_order * mask)
        # get st localized graph: repeat every graph k_t times along the columns
        st_local_graph = []
        for graph in graph_ordered:
            graph = graph.unsqueeze(-2).expand(-1, self.k_t, -1)
            # [num_nodes, kernel_size x num_nodes]
            graph = graph.reshape(
                graph.shape[0], graph.shape[1] * graph.shape[2])
            st_local_graph.append(graph)
        # [order, num_nodes, kernel_size x num_nodes]
        return st_local_graph

    def forward(self, X, dynamic_graph, static_graph):
        """Convolve ``X``.

        Args:
            X: [bs, seq, nodes, feat].
            dynamic_graph: list of dynamic graphs (read only when enabled).
            static_graph: static hidden graphs (read only when enabled).

        Returns:
            Tensor shaped [bs, seq - k_t + 1, nodes, hidden_dim].
        """
        # sliding windows: [bs, seq, num_nodes, ks, num_feat]
        X = X.unfold(1, self.k_t, 1).permute(0, 1, 2, 4, 3)
        # seq_len is changing
        batch_size, seq_len, num_nodes, kernel_size, num_feat = X.shape

        # support
        support = []
        # predefined graph
        if self.use_predefined_graph:
            support = support + self.pre_defined_graph
        # dynamic graph
        if self.use_dynamic_hidden_graph:
            # k_order is computed in the dynamic graph constructor component
            support = support + dynamic_graph
        # predefined graphs and static hidden graphs
        if self.use_static__hidden_graph:
            support = support + self.get_graph(static_graph)

        # parallelize
        X = X.reshape(batch_size, seq_len, num_nodes, kernel_size * num_feat)
        # batch_size, seq_len, num_nodes, kernel_size * hidden_dim
        out = self.activation(self.fc_list_updt(X))
        out = out.view(batch_size, seq_len, num_nodes, kernel_size, num_feat)
        X_0 = torch.mean(out, dim=-2)
        # batch_size, seq_len, kernel_size x num_nodes, hidden_dim
        X_k = out.transpose(-3, -2).reshape(
            batch_size, seq_len, kernel_size * num_nodes, num_feat)
        # Nx3N 3NxD -> NxD: batch_size, seq_len, num_nodes, hidden_dim
        hidden = self.gconv(support, X_k, X_0)
        return hidden


# ---- /basicts/archs/arch_zoo/d2stgnn_arch/difusion_block/forecast.py ----
# import torch
# import torch.nn as nn


class Forecast(nn.Module):
    """Auto-regressive forecast branch of the diffusion block.

    Args:
        hidden_dim: feature dimension of the hidden signals.
        fk_dim: output feature dimension.
        **model_args: must contain ``'k_t'``, ``'seq_length'`` and ``'gap'``.
    """

    def __init__(self, hidden_dim, fk_dim=None, **model_args):
        super().__init__()
        self.k_t = model_args['k_t']
        self.output_seq_len = model_args['seq_length']
        self.forecast_fc = nn.Linear(hidden_dim, fk_dim)
        self.model_args = model_args

    def forward(self, X, H, st_l_conv, dynamic_graph, static_graph):
        """Roll the convolution forward.

        Starts from the last step of ``H`` and repeatedly applies
        ``st_l_conv`` to the most recent ``k_t`` steps, padding with the tail
        of ``X`` while fewer than ``k_t`` predictions exist.

        Returns:
            Predictions with ``int(seq_length / gap)`` steps along dim 1,
            projected to ``fk_dim`` features.
        """
        history = X
        predict = [H[:, -1, :, :].unsqueeze(1)]
        num_extra_steps = int(self.output_seq_len / self.model_args['gap']) - 1
        for _ in range(num_extra_steps):
            window = predict[-self.k_t:]
            if len(window) < self.k_t:
                # not enough predictions yet: prepend the newest history steps
                missing = self.k_t - len(window)
                window = torch.cat(
                    [history[:, -missing:, :, :]] + window, dim=1)
            else:
                window = torch.cat(window, dim=1)
            predict.append(st_l_conv(window, dynamic_graph, static_graph))
        predict = torch.cat(predict, dim=1)
        predict = self.forecast_fc(predict)
        return predict


# ---- /basicts/archs/arch_zoo/d2stgnn_arch/dynamic_graph_conv/dy_graph_conv.py ----
# import torch.nn as nn
# from .utils import *
# (DistanceFunction, Mask, Normalizer and MultiOrder used below are expected
#  to come from that star import.)


class DynamicGraphConstructor(nn.Module):
    """Builds the dynamic graphs: distance -> mask -> normalisation ->
    multi-order expansion -> spatial-temporal localization.

    Args:
        **model_args: must contain ``'k_s'``, ``'k_t'``, ``'num_hidden'`` and
            ``'node_hidden'``; the full dict is forwarded to DistanceFunction
            and Mask.
    """

    def __init__(self, **model_args):
        super().__init__()
        # model args
        self.k_s = model_args['k_s']  # spatial order
        self.k_t = model_args['k_t']  # temporal kernel size
        # hidden dimension
        self.hidden_dim = model_args['num_hidden']
        # trainable node embedding dimension
        self.node_dim = model_args['node_hidden']

        self.distance_function = DistanceFunction(**model_args)
        self.mask = Mask(**model_args)
        self.normalizer = Normalizer()
        self.multi_order = MultiOrder(order=self.k_s)

    def st_localization(self, graph_ordered):
        """Repeat every [B, N, N] graph ``k_t`` times along its columns.

        Args:
            graph_ordered: iterable of iterables of [B, N, N] tensors.

        Returns:
            Flat list of tensors shaped [B, N, k_t * N].
        """
        st_local_graph = []
        for modality_i in graph_ordered:
            for k_order_graph in modality_i:
                k_order_graph = k_order_graph.unsqueeze(
                    -2).expand(-1, -1, self.k_t, -1)
                k_order_graph = k_order_graph.reshape(
                    k_order_graph.shape[0], k_order_graph.shape[1],
                    k_order_graph.shape[2] * k_order_graph.shape[3])
                st_local_graph.append(k_order_graph)
        return st_local_graph

    def forward(self, **inputs):
        """Construct the dynamic graphs.

        Args:
            **inputs: must contain ``'X'``, ``'E_d'``, ``'E_u'``, ``'T_D'``
                and ``'D_W'``.

        Returns:
            List of spatial-temporal localized dynamic graphs.
        """
        X = inputs['X']
        E_d = inputs['E_d']
        E_u = inputs['E_u']
        T_D = inputs['T_D']
        D_W = inputs['D_W']
        # distance calculation
        dist_mx = self.distance_function(X, E_d, E_u, T_D, D_W)
        # mask
        dist_mx = self.mask(dist_mx)
        # normalization
        dist_mx = self.normalizer(dist_mx)
        # multi order
        mul_mx = self.multi_order(dist_mx)
        # spatial temporal localization
        dynamic_graphs = self.st_localization(mul_mx)

        return dynamic_graphs


# ---- /basicts/archs/arch_zoo/d2stgnn_arch/dynamic_graph_conv/utils/__init__.py ----
# from .mask import *
# from .normalizer import *
# from .distance import *

# ---- /basicts/archs/arch_zoo/d2stgnn_arch/dynamic_graph_conv/utils/distance.py ----
import math
class DistanceFunction(nn.Module):
    """Attention-style pairwise score between nodes.

    Per-node time-series features, the last-step time embeddings and a node
    embedding are concatenated, projected to queries/keys and turned into a
    row-softmaxed score matrix, once for each of the two node embeddings.

    Args:
        **model_args: must contain ``num_hidden``, ``node_hidden``,
            ``seq_length``, ``dropout`` and ``time_emb_dim``.
    """

    def __init__(self, **model_args):
        super().__init__()
        # attributes
        self.hidden_dim = model_args['num_hidden']
        self.node_dim = model_args['node_hidden']
        self.time_slot_emb_dim = self.hidden_dim
        self.input_seq_len = model_args['seq_length']
        # Time Series Feature Extraction
        self.dropout = nn.Dropout(model_args['dropout'])
        self.fc_ts_emb1 = nn.Linear(self.input_seq_len, self.hidden_dim * 2)
        self.fc_ts_emb2 = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.ts_feat_dim = self.hidden_dim
        # Time Slot Embedding Extraction
        # NOTE(review): not used in forward(); kept so existing checkpoints
        # still load.
        self.time_slot_embedding = nn.Linear(model_args['time_emb_dim'], self.time_slot_emb_dim)
        # Distance Score
        self.all_feat_dim = self.ts_feat_dim + self.node_dim + model_args['time_emb_dim'] * 2
        self.WQ = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
        self.WK = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
        self.bn = nn.BatchNorm1d(self.hidden_dim * 2)

    def reset_parameters(self):
        """Re-initialise the query/key projections with Xavier-normal weights.

        The previous body iterated over ``self.q_vecs`` and ``self.biases``,
        which this class never defines, so calling it raised AttributeError.
        ``WQ`` and ``WK`` are bias-free, so only their weights are reset.
        """
        for projection in (self.WQ, self.WK):
            nn.init.xavier_normal_(projection.weight)

    def forward(self, X, E_d, E_u, T_D, D_W):
        """Compute two row-normalised score matrices.

        Args:
            X: 4-D input; only feature 0 is used and the time axis (dim 1)
                must have length ``seq_length``.
            E_d: 2-D node embedding, expanded over the batch.
            E_u: 2-D node embedding, expanded over the batch.
            T_D: 4-D tensor; only the last step along dim 1 is used.
            D_W: 4-D tensor; only the last step along dim 1 is used.

        Returns:
            ``[W_d, W_u]``: one score tensor per node embedding, each
            softmaxed over its last dimension.
        """
        # last pooling
        T_D = T_D[:, -1, :, :]
        D_W = D_W[:, -1, :, :]
        # dynamic information: [batch, seq_len, nodes] -> [batch, nodes, seq_len]
        X = X[:, :, :, 0].transpose(1, 2).contiguous()
        batch_size, num_nodes, seq_len = X.shape
        # BatchNorm1d expects 2-D input, so fold nodes into the batch axis.
        X = X.view(batch_size * num_nodes, seq_len)
        dy_feat = self.fc_ts_emb2(self.dropout(self.bn(F.relu(self.fc_ts_emb1(X)))))
        dy_feat = dy_feat.view(batch_size, num_nodes, -1)
        # node embedding
        emb1 = E_d.unsqueeze(0).expand(batch_size, -1, -1)
        emb2 = E_u.unsqueeze(0).expand(batch_size, -1, -1)
        # distance calculation: one hidden state per node embedding
        hidden_states = [
            torch.cat([dy_feat, T_D, D_W, emb1], dim=-1),
            torch.cat([dy_feat, T_D, D_W, emb2], dim=-1),
        ]
        scale = math.sqrt(self.hidden_dim)
        adjacent_list = []
        for state in hidden_states:
            Q = self.WQ(state)
            K = self.WK(state)
            QKT = torch.bmm(Q, K.transpose(-1, -2)) / scale
            adjacent_list.append(torch.softmax(QKT, dim=-1))
        return adjacent_list
graph): 25 | graph_ordered = [] 26 | k_1_order = graph # 1 order 27 | mask = torch.eye(graph.shape[1]).to(graph.device) 28 | mask = 1 - mask 29 | graph_ordered.append(k_1_order * mask) 30 | for k in range(2, self.order+1): # e.g., order = 3, k=[2, 3]; order = 2, k=[2] 31 | k_1_order = torch.matmul(k_1_order, graph) 32 | graph_ordered.append(k_1_order * mask) 33 | return graph_ordered 34 | 35 | def forward(self, adj): 36 | return [self._multi_order(_) for _ in adj] 37 | -------------------------------------------------------------------------------- /basicts/archs/arch_zoo/d2stgnn_arch/inherent_block/__init__.py: -------------------------------------------------------------------------------- 1 | from .inh_block import InhBlock 2 | -------------------------------------------------------------------------------- /basicts/archs/arch_zoo/d2stgnn_arch/inherent_block/forecast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Forecast(nn.Module): 6 | def __init__(self, hidden_dim, fk_dim, **model_args): 7 | super().__init__() 8 | self.output_seq_len = model_args['seq_length'] 9 | self.model_args = model_args 10 | 11 | self.forecast_fc = nn.Linear(hidden_dim, fk_dim) 12 | 13 | def forward(self, X, RNN_H, Z, transformer_layer, rnn_layer, pe): 14 | [B, L, N, D] = X.shape 15 | [L, B_N, D] = RNN_H.shape 16 | [L, B_N, D] = Z.shape 17 | 18 | predict = [Z[-1, :, :].unsqueeze(0)] 19 | for _ in range(int(self.output_seq_len / self.model_args['gap'])-1): 20 | # RNN 21 | _gru = rnn_layer.gru_cell(predict[-1][0], RNN_H[-1]).unsqueeze(0) 22 | RNN_H = torch.cat([RNN_H, _gru], dim=0) 23 | # Positional Encoding 24 | if pe is not None: 25 | RNN_H = pe(RNN_H) 26 | # Transformer 27 | _Z = transformer_layer(_gru, K=RNN_H, V=RNN_H) 28 | predict.append(_Z) 29 | 30 | predict = torch.cat(predict, dim=0) 31 | predict = predict.reshape(-1, B, N, D) 32 | predict = predict.transpose(0, 1) 33 | predict = 
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a sequence-first tensor.

    Args:
        d_model: size of the last dimension of the input.
        dropout: dropout probability applied after the addition. ``None``
            (the default) disables dropout; previously it made
            ``nn.Dropout`` raise.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, dropout=None, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=0.0 if dropout is None else dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine slot than sine slots.
        pe[:, 0, 1::2] = torch.cos(position * div_term[:d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, X):
        """Return ``dropout(X + pe[:len(X)])``; dim 0 of ``X`` is the position axis."""
        X = X + self.pe[:X.size(0)]
        X = self.dropout(X)
        return X
class RNNLayer(nn.Module):
    """Run a GRU cell over the time axis independently for every node.

    Args:
        hidden_dim: input and hidden size of the GRU cell.
        dropout: dropout probability on the stacked outputs. ``None`` (the
            default) disables dropout; previously it made ``nn.Dropout``
            raise.
    """

    def __init__(self, hidden_dim, dropout=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru_cell = nn.GRUCell(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.0 if dropout is None else dropout)

    def forward(self, X):
        """Encode ``X`` step by step.

        Args:
            X: tensor of shape ``[batch_size, seq_len, num_nodes, hidden_dim]``.

        Returns:
            Hidden states of shape ``[seq_len, batch_size * num_nodes, hidden_dim]``.
        """
        batch_size, seq_len, num_nodes, hidden_dim = X.shape
        # Fold nodes into the batch axis so the cell sees 2-D inputs.
        X = X.transpose(1, 2).reshape(
            batch_size * num_nodes, seq_len, hidden_dim)
        hx = torch.zeros_like(X[:, 0, :])
        output = []
        for step in range(seq_len):
            hx = self.gru_cell(X[:, step, :], hx)
            output.append(hx)
        output = torch.stack(output, dim=0)
        return self.dropout(output)
class gconv_hyper(nn.Module):
    """Graph propagation with one adjacency matrix shared across the batch."""

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        """Contract ``x`` (indices ``n, v, c``) with ``A`` (indices ``v, w``).

        ``A`` is first moved to the device of ``x``; the result has indices
        ``n, w, c`` and is returned contiguous.
        """
        adjacency = A.to(x.device)
        propagated = torch.einsum('nvc,vw->nwc', x, adjacency)
        return propagated.contiguous()
class TriangularCausalMask():
    """Boolean mask that is True strictly above the main diagonal.

    Args:
        B: size of the leading dimension.
        L: size of the last two (square) dimensions.
        device: device the mask is moved to.
    """

    def __init__(self, B, L, device="cpu"):
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            self._mask = all_true.triu(diagonal=1).to(device)

    @property
    def mask(self):
        """The ``[B, 1, L, L]`` boolean mask."""
        return self._mask
class LocalMask():
    """Boolean band mask: True above the diagonal or further than ``len`` below it.

    Args:
        B: size of the leading dimension.
        L: number of rows; also sets the band width ``ceil(log2(L))``.
        S: number of columns.
        device: device the masks are moved to.
    """

    def __init__(self, B, L, S, device="cpu"):
        with torch.no_grad():
            self.len = math.ceil(np.log2(L))
            all_true = torch.ones([B, 1, L, S], dtype=torch.bool)
            # True where column index > row index.
            self._mask1 = all_true.triu(diagonal=1).to(device)
            # True where the column lies more than `len` steps below the diagonal.
            self._mask2 = ~all_true.triu(diagonal=-self.len).to(device)
            self._mask = self._mask1 | self._mask2

    @property
    def mask(self):
        """The combined ``[B, 1, L, S]`` boolean mask."""
        return self._mask
class DecoderLayer(nn.Module):
    """Decoder block: self-attention, cross-attention, then a 1x1-conv feed-forward.

    Args:
        self_attention: callable ``(q, k, v, attn_mask=...)`` whose first
            return item is the attended tensor.
        cross_attention: same contract, called with the encoder output as
            keys and values.
        d_model: feature size of the layer.
        d_ff: feed-forward width; falls back to ``4 * d_model`` when falsy.
        dropout: dropout probability shared by all sub-layers.
        activation: ``"relu"`` selects ReLU, anything else GELU.
    """

    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        """Apply the three residual sub-layers and return the normalised result."""
        # Self-attention with residual connection.
        attended = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x = self.norm1(x + self.dropout(attended))

        # Cross-attention over the encoder output.
        attended = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x = self.norm2(x + self.dropout(attended))

        # Position-wise feed-forward; Conv1d wants channels on dim 1.
        y = self.conv1(x.transpose(-1, 1))
        y = self.dropout(self.activation(y))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm3(x + y)
class TokenEmbedding(nn.Module):
    """Embed values with a circular 1-D convolution over the time axis.

    Args:
        c_in: number of input channels (last dimension of the input).
        d_model: number of output channels.
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Compare the version numerically: as strings, '1.10.0' < '1.5.0',
        # which selected padding=2 on torch 1.10-1.13.
        major, minor = (int(part) for part in torch.__version__.split('.')[:2])
        padding = 1 if (major, minor) >= (1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular')
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Map a 3-D ``x`` with channels last to the same layout with ``d_model`` channels."""
        # Conv1d wants channels on dim 1; swap back afterwards.
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
class DataEmbedding(nn.Module):
    """Sum of value, positional and temporal embeddings followed by dropout.

    Args:
        c_in: input channels for the value embedding.
        d_model: embedding size shared by all three embeddings.
        time_of_day_size: forwarded to ``TemporalEmbedding``.
        day_of_week_size: forwarded to ``TemporalEmbedding``.
        day_of_month_size: forwarded to ``TemporalEmbedding``.
        day_of_year_size: forwarded to ``TemporalEmbedding``.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding``; any other
            value is forwarded to ``TemporalEmbedding``.
        num_time_features: input size of ``TimeFeatureEmbedding``.
        dropout: dropout probability on the summed embedding.
    """

    def __init__(self, c_in, d_model, time_of_day_size, day_of_week_size, day_of_month_size, day_of_year_size, embed_type='fixed', num_time_features=-1, dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            temporal = TimeFeatureEmbedding(d_model=d_model, num_time_features=num_time_features)
        else:
            temporal = TemporalEmbedding(d_model=d_model,
                                         time_of_day_size=time_of_day_size,
                                         day_of_week_size=day_of_week_size,
                                         day_of_month_size=day_of_month_size,
                                         day_of_year_size=day_of_year_size,
                                         embed_type=embed_type)
        self.temporal_embedding = temporal

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed values ``x`` and time marks ``x_mark`` and return their sum after dropout."""
        embedded = self.value_embedding(x)
        embedded = embedded + self.position_embedding(x)
        embedded = embedded + self.temporal_embedding(x_mark)
        return self.dropout(embedded)
class ConvLayer(nn.Module):
    """Down-sampling block: circular conv, batch norm, ELU and stride-2 max pooling.

    Args:
        c_in: number of channels, unchanged by the layer.
    """

    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        # Compare the version numerically: as strings, '1.10.0' < '1.5.0',
        # which selected padding=2 on torch 1.10-1.13.
        major, minor = (int(part) for part in torch.__version__.split('.')[:2])
        padding = 1 if (major, minor) >= (1, 5) else 2
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=3,
                                  padding=padding,
                                  padding_mode='circular')
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the block to a 3-D ``x`` with channels last; the pooling shortens dim 1."""
        # Conv1d / BatchNorm1d / MaxPool1d all want channels on dim 1.
        x = self.downConv(x.permute(0, 2, 1))
        x = self.norm(x)
        x = self.activation(x)
        x = self.maxPool(x)
        x = x.transpose(1, 2)
        return x
class Encoder(nn.Module):
    """Stack of attention layers, optionally interleaved with conv layers.

    Args:
        attn_layers: modules called as ``layer(x, attn_mask=...)`` and
            returning ``(x, attn)``.
        conv_layers: optional modules applied after each paired attention
            layer; the last attention layer is always applied on its own
            afterwards.
        norm_layer: optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is None:
            self.conv_layers = None
        else:
            self.conv_layers = nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Run all layers and return ``(x, attns)`` with one entry per attention call."""
        attns = []
        if self.conv_layers is None:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
                x = conv_layer(x)
            # The final attention layer has no conv layer after it.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
class moving_avg(nn.Module):
    """Moving average block to highlight the trend of time series.

    Args:
        kernel_size: window length of the average.
        stride: stride of the underlying ``nn.AvgPool1d``.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size,
                                stride=stride, padding=0)

    def forward(self, x):
        """Average a 3-D ``x`` along dim 1 after replicating its end points.

        ``kernel_size - 1`` steps are padded in total, so with ``stride=1``
        the output keeps the length of dim 1 for odd and even window
        lengths. Previously an even ``kernel_size`` padded one step too few
        and returned a sequence one step shorter than the input.
        """
        # padding on the both ends of time series
        front_len = (self.kernel_size - 1) // 2
        end_len = self.kernel_size - 1 - front_len  # equals front_len for odd kernels
        front = x[:, 0:1, :].repeat(1, front_len, 1)
        end = x[:, -1:, :].repeat(1, end_len, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dimension.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
class DLinear(nn.Module):
    """
    The implementation of the decomposition-linear model in Paper "Are Transformers Effective for Time Series Forecasting?"
    Link: https://arxiv.org/abs/2205.13504
    """

    def __init__(self, **model_args):
        super().__init__()
        self.seq_len = model_args["seq_len"]
        self.pred_len = model_args["pred_len"]

        # Decomposition kernel size
        kernel_size = 25
        self.decompsition = series_decomp(kernel_size)
        self.individual = model_args["individual"]
        self.channels = model_args["enc_in"]

        if not self.individual:
            # One seasonal and one trend projection shared by all channels.
            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
        else:
            # A separate seasonal/trend pair per channel.
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            for _ in range(self.channels):
                self.Linear_Seasonal.append(
                    nn.Linear(self.seq_len, self.pred_len))
                self.Linear_Trend.append(
                    nn.Linear(self.seq_len, self.pred_len))

    def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor:
        """Feed forward of DLinear.

        Args:
            history_data (torch.Tensor): history data with shape [B, L, N, 1]

        Returns:
            torch.Tensor: prediction with shape [B, pred_len, N, 1]
        """

        assert history_data.shape[-1] == 1      # only use the target feature
        series = history_data[..., 0]           # B, L, N
        seasonal_init, trend_init = self.decompsition(series)
        # Put time last so the linear layers act along it: B, N, L
        seasonal_init = seasonal_init.permute(0, 2, 1)
        trend_init = trend_init.permute(0, 2, 1)

        if not self.individual:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)
        else:
            seasonal_output = seasonal_init.new_zeros(
                [seasonal_init.size(0), seasonal_init.size(1), self.pred_len])
            trend_output = trend_init.new_zeros(
                [trend_init.size(0), trend_init.size(1), self.pred_len])
            for channel in range(self.channels):
                seasonal_layer = self.Linear_Seasonal[channel]
                trend_layer = self.Linear_Trend[channel]
                seasonal_output[:, channel, :] = seasonal_layer(seasonal_init[:, channel, :])
                trend_output[:, channel, :] = trend_layer(trend_init[:, channel, :])

        prediction = seasonal_output + trend_output
        return prediction.permute(0, 2, 1).unsqueeze(-1)  # [B, L, N, 1]
7 | Link: https://arxiv.org/abs/2205.13504 8 | """ 9 | 10 | def __init__(self, **model_args): 11 | super(Linear, self).__init__() 12 | self.seq_len = model_args["seq_len"] 13 | self.pred_len = model_args["pred_len"] 14 | self.Linear = nn.Linear(self.seq_len, self.pred_len) 15 | 16 | def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor: 17 | """Feed forward of STID. 18 | 19 | Args: 20 | history_data (torch.Tensor): history data with shape [B, L, N, C] 21 | 22 | Returns: 23 | torch.Tensor: prediction wit shape [B, L, N, C] 24 | """ 25 | 26 | assert history_data.shape[-1] == 1 # only use the target feature 27 | history_data = history_data[..., 0] # B, L, N 28 | prediction = self.Linear(history_data.permute(0, 2, 1)).permute(0, 2, 1).unsqueeze(-1) # B, L, N, 1 29 | return prediction 30 | -------------------------------------------------------------------------------- /basicts/archs/arch_zoo/linear_arch/nlinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class NLinear(nn.Module): 5 | """ 6 | The implementation of the normalization-linear model in Paper "Are Transformers Effective for Time Series Forecasting?" 7 | Link: https://arxiv.org/abs/2205.13504 8 | """ 9 | 10 | def __init__(self, **model_args): 11 | super(NLinear, self).__init__() 12 | self.seq_len = model_args["seq_len"] 13 | self.pred_len = model_args["pred_len"] 14 | self.Linear = nn.Linear(self.seq_len, self.pred_len) 15 | 16 | def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor: 17 | """Feed forward of STID. 
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout applied to the attention weights."""

    def __init__(self, temperature, attn_dropout=0.2):
        """
        Args:
            temperature: divisor applied to the queries before the dot product.
            attn_dropout: dropout probability for the attention weights.
        """
        super().__init__()

        self.dropout = nn.Dropout(attn_dropout)
        self.temperature = temperature

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Dimensions 2 and 3 of ``k`` are swapped for the dot product, so the
        inputs need at least four dimensions.

        Args:
            q, k, v: query, key and value tensors.
            mask: optional boolean mask; positions where it is True are filled
                with -1e9 before the softmax.

        Returns:
            tuple: (attention output, attention weights after dropout)
        """
        scores = (q / self.temperature) @ k.transpose(2, 3)

        if mask is not None:
            # a large negative score makes the softmax weight vanish
            scores = scores.masked_fill(mask, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, v), weights
k = k.float().contiguous() 51 | # attn_weights.size(): (batch_size, L, num_heads, 11) 52 | attn_weights = graph_mm_tvm(q, k, self.q_k_mask, self.k_q_mask, False, -1000000000) 53 | attn_weights = self.dropout_attn(F.softmax(attn_weights, dim=-1)) 54 | 55 | v = v.view(bsz, seq_len, self.n_head, self.d_k) 56 | v = v.float().contiguous() 57 | # is_t1_diagonaled=True 58 | attn = graph_mm_tvm(attn_weights, v, self.q_k_mask, self.k_q_mask, True, 0) 59 | attn = attn.reshape(bsz, seq_len, self.n_head * self.d_k).contiguous() 60 | context = self.dropout_fc(self.fc(attn)) 61 | context += residual 62 | 63 | if not self.normalize_before: 64 | context = self.layer_norm(context) 65 | 66 | return context 67 | 68 | -------------------------------------------------------------------------------- /basicts/archs/arch_zoo/pyraformer_arch/sub_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .modules import ScaledDotProductAttention 5 | 6 | 7 | class MultiHeadAttention(nn.Module): 8 | """ Multi-Head Attention module """ 9 | 10 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, normalize_before=True): 11 | super().__init__() 12 | 13 | self.normalize_before = normalize_before 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 21 | nn.init.xavier_uniform_(self.w_qs.weight) 22 | nn.init.xavier_uniform_(self.w_ks.weight) 23 | nn.init.xavier_uniform_(self.w_vs.weight) 24 | 25 | self.fc = nn.Linear(d_v * n_head, d_model) 26 | nn.init.xavier_uniform_(self.fc.weight) 27 | 28 | self.attention = ScaledDotProductAttention( 29 | temperature=d_k ** 0.5, attn_dropout=dropout) 30 | 31 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 32 | self.dropout = nn.Dropout(dropout) 
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward network with a residual connection."""

    def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True):
        """
        Args:
            d_in: input (and output) feature size.
            d_hid: hidden feature size.
            dropout: dropout probability used after each linear layer.
            normalize_before: apply layer norm before the block (True) or
                after the residual sum (False).
        """
        super().__init__()

        self.normalize_before = normalize_before

        # expand to d_hid, then project back to d_in
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)

        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply (optional pre-norm) -> linear -> GELU -> dropout -> linear -> dropout -> residual -> (optional post-norm)."""
        pre_norm = self.normalize_before

        hidden = self.layer_norm(x) if pre_norm else x
        hidden = self.dropout(F.gelu(self.w_1(hidden)))
        hidden = self.dropout(self.w_2(hidden))

        # the residual always uses the un-normalized input
        out = hidden + x
        return out if pre_norm else self.layer_norm(out)
| 17 | # STGCNChebGraphConv contains 'TGTND TGTND TNFF' structure 18 | # ChebGraphConv is the graph convolution from ChebyNet. 19 | # Using the Chebyshev polynomials of the first kind as a graph filter. 20 | 21 | # T: Gated Temporal Convolution Layer (GLU or GTU) 22 | # G: Graph Convolution Layer (ChebGraphConv) 23 | # T: Gated Temporal Convolution Layer (GLU or GTU) 24 | # N: Layer Normolization 25 | # D: Dropout 26 | 27 | # T: Gated Temporal Convolution Layer (GLU or GTU) 28 | # G: Graph Convolution Layer (ChebGraphConv) 29 | # T: Gated Temporal Convolution Layer (GLU or GTU) 30 | # N: Layer Normolization 31 | # D: Dropout 32 | 33 | # T: Gated Temporal Convolution Layer (GLU or GTU) 34 | # N: Layer Normalization 35 | # F: Fully-Connected Layer 36 | # F: Fully-Connected Layer 37 | 38 | def __init__(self, Kt, Ks, blocks, T, n_vertex, act_func, graph_conv_type, gso, bias, droprate): 39 | super(STGCNChebGraphConv, self).__init__() 40 | modules = [] 41 | for l in range(len(blocks) - 3): 42 | modules.append(STConvBlock( 43 | Kt, Ks, n_vertex, blocks[l][-1], blocks[l+1], act_func, graph_conv_type, gso, bias, droprate)) 44 | self.st_blocks = nn.Sequential(*modules) 45 | Ko = T - (len(blocks) - 3) * 2 * (Kt - 1) 46 | self.Ko = Ko 47 | assert Ko != 0, "Ko = 0." 48 | self.output = OutputBlock( 49 | Ko, blocks[-3][-1], blocks[-2], blocks[-1][0], n_vertex, act_func, bias, droprate) 50 | 51 | def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor: 52 | """feedforward function of STGCN. 
53 | 54 | Args: 55 | history_data (torch.Tensor): historical data with shape [B, L, N, C] 56 | 57 | Returns: 58 | torch.Tensor: prediction with shape [B, L, N, C] 59 | """ 60 | x = history_data.permute(0, 3, 1, 2).contiguous() 61 | 62 | x = self.st_blocks(x) 63 | x = self.output(x) 64 | 65 | x = x.transpose(2, 3) 66 | return x 67 | -------------------------------------------------------------------------------- /basicts/archs/arch_zoo/stid_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .stid_arch import STID 2 | 3 | __all__ = ["STID"] 4 | -------------------------------------------------------------------------------- /basicts/archs/arch_zoo/stid_arch/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MultiLayerPerceptron(nn.Module): 6 | """Multi-Layer Perceptron with residual links.""" 7 | 8 | def __init__(self, input_dim, hidden_dim) -> None: 9 | super().__init__() 10 | self.fc1 = nn.Conv2d( 11 | in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True) 12 | self.fc2 = nn.Conv2d( 13 | in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True) 14 | self.act = nn.ReLU() 15 | self.drop = nn.Dropout(p=0.15) 16 | 17 | def forward(self, input_data: torch.Tensor) -> torch.Tensor: 18 | """Feed forward of MLP. 
def date_normalize(data: torch.Tensor, time_of_day_size: int, day_of_week_size: int, day_of_month_size: int, day_of_year_size: int):
    """Normalize the date features in place.

    Feature ``i`` of the last axis is paired with the ``i``-th size argument.
    A feature is rescaled to ``value / (size - 1) - 0.5`` only when its size
    is given and every value in it is a whole number (i.e. it looks like it
    has not been normalized yet). Features whose size is None are never read.

    Args:
        data (torch.Tensor): the date features with shape: [B, L, C]. Modified in place.
        time_of_day_size (int): size used for feature 0 (time in day), or None to skip.
        day_of_week_size (int): size used for feature 1 (day in week), or None to skip.
        day_of_month_size (int): size used for feature 2 (day in month), or None to skip.
        day_of_year_size (int): size used for feature 3, or None to skip.

    Returns:
        torch.Tensor: the same ``data`` tensor.
    """
    feature_sizes = (time_of_day_size, day_of_week_size, day_of_month_size, day_of_year_size)
    for channel, size in enumerate(feature_sizes):
        if size is None:
            continue
        column = data[:, :, channel]
        # whole-number values mean the feature is still a raw index
        if torch.all(column.round() == column):
            data[:, :, channel] = column / (size - 1) - 0.5

    return data
Shape: [B, start_token_length + L2, N] 44 | torch.Tensor: x_mark_dec, time features input to decoder w.r.t. x_dec. Shape: [B, start_token_length + L2, C-1] 45 | """ 46 | 47 | # get the x_enc 48 | x_enc = history_data[..., 0] # B, L1, N 49 | # get the corresponding x_mark_enc 50 | x_mark_enc = history_data[:, :, 0, 1:] # B, L1, C-1 51 | if embed_type == 'timeF': # use as the time features 52 | x_mark_enc = date_normalize(x_mark_enc, time_of_day_size, day_of_week_size, day_of_month_size, day_of_year_size) 53 | 54 | # get the x_dec 55 | if start_token_len == 0: 56 | x_dec = torch.zeros_like(future_data[..., 0]) # B, L2, N 57 | # get the corresponding x_mark_dec 58 | x_mark_dec = future_data[..., :, 0, 1:] # B, L2, C-1 59 | x_mark_dec = date_normalize(x_mark_dec, time_of_day_size, day_of_week_size, day_of_month_size, day_of_year_size) 60 | return x_enc, x_mark_enc, x_dec, x_mark_dec 61 | else: 62 | x_dec_token = x_enc[:, -start_token_len:, :] # B, start_token_length, N 63 | x_dec_zeros = torch.zeros_like(future_data[..., 0]) # B, L2, N 64 | x_dec = torch.cat([x_dec_token, x_dec_zeros], dim=1) # B, (start_token_length+L2), N 65 | # get the corresponding x_mark_dec 66 | x_mark_dec_token = x_mark_enc[:, -start_token_len:, :] # B, start_token_length, C-1 67 | x_mark_dec_future = future_data[..., :, 0, 1:] # B, L2, C-1 68 | x_mark_dec_future = date_normalize(x_mark_dec_future, time_of_day_size, day_of_week_size, day_of_month_size, day_of_year_size) 69 | x_mark_dec = torch.cat([x_mark_dec_token, x_mark_dec_future], dim=1) # B, (start_token_length+L2), C-1 70 | 71 | return x_enc.float(), x_mark_enc.float(), x_dec.float(), x_mark_dec.float() 72 | -------------------------------------------------------------------------------- /basicts/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from easytorch.utils.registry import scan_modules 4 | 5 | from .registry import SCALER_REGISTRY 6 | from .dataset import 
class TimeSeriesForecastingDataset(Dataset):
    """Time series forecasting dataset.

    Holds the whole (already normalized) series as one float tensor and an
    index list describing each sample for the selected split.
    """

    def __init__(self, data_file_path: str, index_file_path: str, mode: str) -> None:
        """Load the data and the index of one split.

        Args:
            data_file_path (str): pickle file holding a dict with key "processed_data".
            index_file_path (str): pickle file holding a dict keyed by split name.
            mode (str): one of "train", "valid", "test".

        Raises:
            FileNotFoundError: if either file does not exist.
        """
        super().__init__()
        assert mode in ["train", "valid", "test"], "error mode"
        self._check_if_file_exists(data_file_path, index_file_path)
        # read raw data (normalized)
        data = load_pkl(data_file_path)
        processed_data = data["processed_data"]
        self.data = torch.from_numpy(processed_data).float()
        # read index
        self.index = load_pkl(index_file_path)[mode]

    def _check_if_file_exists(self, data_file_path: str, index_file_path: str):
        """Check if data file and index file exist.

        Args:
            data_file_path (str): data file path
            index_file_path (str): index file path

        Raises:
            FileNotFoundError: no data file
            FileNotFoundError: no index file
        """

        if not os.path.isfile(data_file_path):
            raise FileNotFoundError("BasicTS can not find data file {0}".format(data_file_path))
        if not os.path.isfile(index_file_path):
            raise FileNotFoundError("BasicTS can not find index file {0}".format(index_file_path))

    def __getitem__(self, index: int) -> tuple:
        """Get a sample.

        Args:
            index (int): the iteration index (not the self.index)

        Returns:
            tuple: (future_data, history_data), where the shape of each is L x N x C.
        """

        idx = list(self.index[index])
        if isinstance(idx[0], int):
            # continuous index: (history start, current time, future end)
            history_data = self.data[idx[0]:idx[1]]
            future_data = self.data[idx[1]:idx[2]]
        else:
            # discontinuous index or custom index
            # NOTE: current time $t$ should not be included in idx[0]
            assert idx[1] not in idx[0], "current time t should not included in the idx[0]"
            # Copy before extending: idx[0] is the very list stored in
            # self.index, so appending to it directly would corrupt the index
            # and make the assertion above fail the next time this sample is
            # requested.
            history_index = list(idx[0])
            history_index.append(idx[1])
            history_data = self.data[history_index]
            # NOTE(review): this indexes dim 0 with idx[1] and dim 1 with
            # idx[2] rather than selecting a time range -- confirm this is
            # the intended layout for custom indexes.
            future_data = self.data[idx[1], idx[2]]

        return future_data, history_data

    def __len__(self):
        """Dataset length

        Returns:
            int: dataset length
        """

        return len(self.index)
@SCALER_REGISTRY.register()
def re_standard_transform(data: torch.Tensor, **kwargs) -> torch.Tensor:
    """Standard re-transformation: ``data * std + mean``.

    Args:
        data (torch.Tensor): input data.
        **kwargs: must contain "mean" and "std". When "mean" is a numpy
            array, both statistics are converted to tensors matching
            ``data`` (type and device) with a leading axis added.

    Returns:
        torch.Tensor: re-scaled data.
    """

    mean, std = kwargs["mean"], kwargs["std"]

    if isinstance(mean, np.ndarray):
        def _as_batched_tensor(stat):
            # match data's type/device and add a leading axis
            return torch.from_numpy(stat).type_as(data).to(data.device).unsqueeze(0)

        mean, std = _as_batched_tensor(mean), _as_batched_tensor(std)

    return data * std + mean
78 | 79 | Returns: 80 | np.array: normalized raw time series data. 81 | """ 82 | 83 | # L, N, C, C=1 84 | data_train = data[:train_index[-1][1], ...] 85 | 86 | min_value = data_train.min(axis=(0, 1), keepdims=False)[0] 87 | max_value = data_train.max(axis=(0, 1), keepdims=False)[0] 88 | 89 | print("min: (training data)", min_value) 90 | print("max: (training data)", max_value) 91 | scaler = {} 92 | scaler["func"] = re_min_max_transform.__name__ 93 | scaler["args"] = {"min_value": min_value, "max_value": max_value} 94 | # label to identify the scaler for different settings. 95 | # To be fair, only one transformation can be implemented per dataset. 96 | # TODO: Therefore we (for now) do not distinguish between the data produced by the different transformation methods. 97 | with open(output_dir + "/scaler_in{0}_out{1}.pkl".format(history_seq_len, future_seq_len), "wb") as f: 98 | pickle.dump(scaler, f) 99 | 100 | def normalize(x): 101 | # ref: 102 | # https://github.com/guoshnBJTU/ASTGNN/blob/f0f8c2f42f76cc3a03ea26f233de5961c79c9037/lib/utils.py#L17 103 | x = 1. * (x - min_value) / (max_value - min_value) 104 | x = 2. * x - 1. 105 | return x 106 | 107 | data_norm = normalize(data) 108 | return data_norm 109 | 110 | 111 | @SCALER_REGISTRY.register() 112 | def re_min_max_transform(data: torch.Tensor, **kwargs) -> torch.Tensor: 113 | """Standard re-min-max transform. 114 | 115 | Args: 116 | data (torch.Tensor): input data. 117 | 118 | Returns: 119 | torch.Tensor: re-scaled data. 120 | """ 121 | 122 | min_value, max_value = kwargs["min_value"], kwargs["max_value"] 123 | # ref: 124 | # https://github.com/guoshnBJTU/ASTGNN/blob/f0f8c2f42f76cc3a03ea26f233de5961c79c9037/lib/utils.py#L23 125 | data = (data + 1.) / 2. 126 | data = 1. 
def launch_training(cfg: Union[Dict, str], gpus: str = None, node_rank: int = 0):
    """Extended easytorch launch_training.

    Tries the ``devices=`` keyword first and falls back to the ``gpus=``
    keyword when easytorch rejects the call with an
    unexpected-keyword-argument TypeError.

    Args:
        cfg (Union[Dict, str]): Easytorch config.
        gpus (str): set ``CUDA_VISIBLE_DEVICES`` environment variable.
        node_rank (int): Rank of the current node.
    """

    # Placeholder for possible future pre-processing, such as registering
    # models/runners or checking the config.

    # launch training based on easytorch
    try:
        easytorch.launch_training(cfg=cfg, devices=gpus, node_rank=node_rank)
    except TypeError as err:
        if "launch_training() got an unexpected keyword argument" not in repr(err):
            raise err
        # NOTE: for earlier easytorch version
        easytorch.launch_training(cfg=cfg, gpus=gpus, node_rank=node_rank)
def masked_mae(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
    """Masked mean absolute error.

    Positions whose label equals ``null_val`` (NaN by default, otherwise
    within an absolute tolerance of 5e-5) are excluded. The remaining
    positions are re-weighted so the result is the mean error over valid
    positions; when no position is valid the result is 0.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): labels
        null_val (float, optional): null value. Defaults to np.nan.

    Returns:
        torch.Tensor: masked mean absolute error
    """

    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        eps = 5e-5
        # Build the reference value with the labels' dtype and device so that
        # torch.isclose is never given mismatched dtypes (e.g. float64 labels).
        null_tensor = torch.tensor(null_val, dtype=labels.dtype, device=labels.device)
        mask = ~torch.isclose(labels, null_tensor.expand_as(labels), atol=eps, rtol=0.)
    mask = mask.float()
    # scale the mask so that averaging over all positions equals averaging over valid ones
    mask /= torch.mean(mask)
    # 0 / 0 when nothing is valid -> treat as fully masked
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(preds - labels)
    loss = loss * mask
    # NaN labels (or NaN * 0) contribute nothing
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)
def masked_mse(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
    """Masked mean squared error.

    Positions where ``labels`` equals ``null_val`` (or is NaN when ``null_val``
    is NaN) get zero weight; the remaining weights are rescaled so that they
    average to one over the whole tensor.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): labels
        null_val (float, optional): null value. Defaults to np.nan.

    Returns:
        torch.Tensor: masked mean squared error
    """

    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        eps = 5e-5
        # Build the reference with the dtype/device of `labels`: torch.isclose
        # refuses operands of different dtypes (e.g. float64 labels).
        null_tensor = torch.full_like(labels, null_val)
        mask = ~torch.isclose(labels, null_tensor, atol=eps, rtol=0.)
    mask = mask.float()
    mask /= torch.mean(mask)
    # an all-null batch gives 0/0 = NaN weights; treat them as zero
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = (preds - labels) ** 2
    loss = loss * mask
    # NaN labels propagate through the product (NaN * 0); zero them out
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def masked_rmse(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
    """Masked root mean squared error: the square root of ``masked_mse``.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): labels
        null_val (float, optional): null value. Defaults to np.nan.

    Returns:
        torch.Tensor: root mean squared error
    """

    return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val))
class DCRNNRunner(BaseTimeSeriesForecastingRunner):
    """Runner for DCRNN: add setup_graph and teacher forcing."""

    def __init__(self, cfg: dict):
        super().__init__(cfg)
        # Channel indices read from cfg["MODEL"]; None means "keep every channel".
        self.forward_features = cfg["MODEL"].get("FORWARD_FEATURES", None)
        self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None)

    def setup_graph(self, data):
        """The dcrnn official codes act like tensorflow, which create parameters in the first feedforward process."""
        try:
            self.train_iters(1, 0, data)
        except AttributeError:
            # Deliberate best-effort: this dry run only exists to trigger the
            # lazy parameter creation, so an AttributeError here is ignored.
            pass

    def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
        """Select input features and reshape data to fit the target model.

        Args:
            data (torch.Tensor): input history data, shape [B, L, N, C].

        Returns:
            torch.Tensor: reshaped data
        """

        # select feature using self.forward_features
        if self.forward_features is not None:
            data = data[:, :, :, self.forward_features]
        return data

    def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
        """Select target features and reshape data back to the BasicTS framework

        Args:
            data (torch.Tensor): prediction of the model with arbitrary shape.

        Returns:
            torch.Tensor: reshaped data with shape [B, L, N, C]
        """

        # select feature using self.target_features
        # Indexing with None would insert a new axis instead of selecting
        # channels, so an unset TARGET_FEATURES keeps the data unchanged.
        if self.target_features is not None:
            data = data[:, :, :, self.target_features]
        return data

    def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: bool = True) -> tuple:
        """Feed forward process for train, val, and test. Note that the outputs are NOT re-scaled.

        Args:
            data (tuple): data (future data, history data). [B, L, N, C] for each of them
            epoch (int, optional): epoch number. Defaults to None.
            iter_num (int, optional): iteration number. Defaults to None.
            train (bool, optional): if in the training process. Defaults to True.

        Returns:
            tuple: (prediction, real_value)
        """

        # preprocess
        future_data, history_data = data
        history_data = self.to_running_device(history_data)      # B, L, N, C
        future_data = self.to_running_device(future_data)        # B, L, N, C
        batch_size, length, num_nodes, _ = future_data.shape

        history_data = self.select_input_features(history_data)
        if train:
            # teacher forcing only use the first dimension.
            future_data_4_dec = future_data[..., [0]]
        else:
            future_data_4_dec = None

        # feed forward; batch_seen is only passed while the model is in training mode
        prediction_data = self.model(history_data=history_data, future_data=future_data_4_dec,
                                     batch_seen=iter_num if self.model.training else None, epoch=epoch)
        assert list(prediction_data.shape)[:3] == [batch_size, length, num_nodes], \
            "error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
        # post process
        prediction = self.select_target_features(prediction_data)
        real_value = self.select_target_features(future_data)
        return prediction, real_value
input features and reshape data to fit the target model. 20 | 21 | Args: 22 | data (torch.Tensor): input history data, shape [B, L, N, C]. 23 | 24 | Returns: 25 | torch.Tensor: reshaped data 26 | """ 27 | 28 | # select feature using self.forward_features 29 | if self.forward_features is not None: 30 | data = data[:, :, :, self.forward_features] 31 | return data 32 | 33 | def select_target_features(self, data: torch.Tensor) -> torch.Tensor: 34 | """Select target features and reshape data back to the BasicTS framework 35 | 36 | Args: 37 | data (torch.Tensor): prediction of the model with arbitrary shape. 38 | 39 | Returns: 40 | torch.Tensor: reshaped data with shape [B, L, N, C] 41 | """ 42 | 43 | # select feature using self.target_features 44 | data = data[:, :, :, self.target_features] 45 | return data 46 | 47 | def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple: 48 | """feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. 49 | 50 | Args: 51 | data (tuple): data (future data, history data). [B, L, N, C] for each of them 52 | epoch (int, optional): epoch number. Defaults to None. 53 | iter_num (int, optional): iteration number. Defaults to None. 54 | train (bool, optional): if in the training process. Defaults to True. 55 | 56 | Returns: 57 | tuple: (prediction, real_value) 58 | """ 59 | 60 | # preprocess 61 | future_data, history_data = data 62 | history_data = self.to_running_device(history_data) # B, L, N, C 63 | future_data = self.to_running_device(future_data) # B, L, N, C 64 | batch_size, length, num_nodes, _ = future_data.shape 65 | 66 | history_data = self.select_input_features(history_data) 67 | if train: 68 | # teacher forcing only use the first dimension. 
class HIRunner(SimpleTimeSeriesForecastingRunner):
    """Runner whose ``backward`` step does nothing."""

    def backward(self, loss: torch.Tensor):
        """Ignore ``loss``: no operation is performed and ``None`` is returned."""
        return None
| -------------------------------------------------------------------------------- /basicts/runners/runner_zoo/megecrn_runner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..base_tsf_runner import BaseTimeSeriesForecastingRunner 4 | 5 | 6 | class MegaCRNRunner(BaseTimeSeriesForecastingRunner): 7 | def __init__(self, cfg: dict): 8 | super().__init__(cfg) 9 | self.forward_features = cfg["MODEL"].get("FORWARD_FEATURES", None) 10 | self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None) 11 | 12 | def setup_graph(self, data): 13 | try: 14 | self.train_iters(1, 0, data) 15 | except AttributeError: 16 | pass 17 | 18 | def select_input_features(self, data: torch.Tensor) -> torch.Tensor: 19 | """Select input features and reshape data to fit the target model. 20 | 21 | Args: 22 | data (torch.Tensor): input history data, shape [B, L, N, C]. 23 | 24 | Returns: 25 | torch.Tensor: reshaped data 26 | """ 27 | 28 | # select feature using self.forward_features 29 | if self.forward_features is not None: 30 | data = data[:, :, :, self.forward_features] 31 | return data 32 | 33 | def select_target_features(self, data: torch.Tensor) -> torch.Tensor: 34 | """Select target features and reshape data back to the BasicTS framework 35 | 36 | Args: 37 | data (torch.Tensor): prediction of the model with arbitrary shape. 38 | 39 | Returns: 40 | torch.Tensor: reshaped data with shape [B, L, N, C] 41 | """ 42 | 43 | # select feature using self.target_features 44 | data = data[:, :, :, self.target_features] 45 | return data 46 | 47 | def forward(self, data: tuple, epoch:int = None, iter_num: int = None, train:bool = True, **kwargs) -> tuple: 48 | """feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. 49 | 50 | Args: 51 | data (tuple): data (future data, history data). [B, L, N, C] for each of them 52 | epoch (int, optional): epoch number. Defaults to None. 
class MTGNNRunner(BaseTimeSeriesForecastingRunner):
    """Runner for MTGNN: trains on permuted node subsets and passes their indices to the model."""

    def __init__(self, cfg: dict):
        super().__init__(cfg)
        # Channel indices read from cfg["MODEL"]; None means "keep every channel".
        self.forward_features = cfg["MODEL"].get("FORWARD_FEATURES", None)
        self.target_features = cfg["MODEL"].get("TARGET_FEATURES", None)
        # graph training
        self.step_size = cfg.TRAIN.CUSTOM.STEP_SIZE
        self.num_nodes = cfg.TRAIN.CUSTOM.NUM_NODES
        self.num_split = cfg.TRAIN.CUSTOM.NUM_SPLIT
        # node permutation, (re)drawn in train_iters
        self.perm = None

    def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
        """Select input features.

        Args:
            data (torch.Tensor): input history data, shape [B, L, N, C]

        Returns:
            torch.Tensor: reshaped data
        """

        # select feature using self.forward_features
        if self.forward_features is not None:
            data = data[:, :, :, self.forward_features]
        return data

    def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
        """Select target feature

        Args:
            data (torch.Tensor): prediction of the model with arbitrary shape.

        Returns:
            torch.Tensor: reshaped data with shape [B, L, N, C]
        """

        # select feature using self.target_features
        # Indexing with None would insert a new axis instead of selecting
        # channels, so an unset TARGET_FEATURES keeps the data unchanged.
        if self.target_features is not None:
            data = data[:, :, :, self.target_features]
        return data

    def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> tuple:
        """Feed forward process for train, val, and test. Note that the outputs are NOT re-scaled.

        Args:
            data (tuple): data (future data, history data). [B, L, N, C] for each of them.
                In training a third element, the node index tensor, is expected.
            epoch (int, optional): epoch number. Defaults to None.
            iter_num (int, optional): iteration number. Defaults to None.
            train (bool, optional): if in the training process. Defaults to True.

        Returns:
            tuple: (prediction, real_value). [B, L, N, C] for each of them.
        """

        if train:
            future_data, history_data, idx = data
        else:
            future_data, history_data = data
            idx = None

        history_data = self.to_running_device(history_data)      # B, L, N, C
        future_data = self.to_running_device(future_data)        # B, L, N, C
        batch_size, seq_len, num_nodes, _ = future_data.shape

        history_data = self.select_input_features(history_data)

        prediction_data = self.model(
            history_data=history_data, idx=idx, batch_seen=iter_num, epoch=epoch)   # B, L, N, C
        assert list(prediction_data.shape)[:3] == [
            batch_size, seq_len, num_nodes], "error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
        # post process
        prediction = self.select_target_features(prediction_data)
        real_value = self.select_target_features(future_data)
        return prediction, real_value

    def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
        """Run one training iteration over every node split.

        The nodes are permuted every ``step_size`` iterations and cut into
        ``num_split`` groups (the last group takes the remainder). For each
        group the parent ``train_iters`` computes the loss on the node-sliced
        batch and ``self.backward`` is called on it right away.

        Args:
            epoch (int): current epoch.
            iter_index (int): current iter.
            data (torch.Tensor or tuple): Data provided by DataLoader,
                unpacked here as (future_data, history_data).

        Returns:
            None: ``backward`` has already been called per split, so no loss
            is handed back to the caller.
        """

        # Also draw a permutation if none exists yet, e.g. when the first
        # iteration index seen is not a multiple of step_size.
        if iter_index % self.step_size == 0 or self.perm is None:
            self.perm = np.random.permutation(range(self.num_nodes))
        num_sub = int(self.num_nodes / self.num_split)
        # Unpack once, outside the loop, so every split slices the full batch.
        future_data, history_data = data
        for j in range(self.num_split):
            if j != self.num_split - 1:
                idx = self.perm[j * num_sub:(j + 1) * num_sub]
            else:
                idx = self.perm[j * num_sub:]
            idx = torch.tensor(idx)
            sub_data = future_data[:, :, idx, :], history_data[:, :, idx, :], idx
            loss = super().train_iters(epoch, iter_index, sub_data)
            self.backward(loss)
16 | 17 | Args: 18 | data (torch.Tensor): input history data, shape [B, L, N, C] 19 | 20 | Returns: 21 | torch.Tensor: reshaped data 22 | """ 23 | 24 | # select feature using self.forward_features 25 | if self.forward_features is not None: 26 | data = data[:, :, :, self.forward_features] 27 | return data 28 | 29 | def select_target_features(self, data: torch.Tensor) -> torch.Tensor: 30 | """Select target feature. 31 | 32 | Args: 33 | data (torch.Tensor): prediction of the model with arbitrary shape. 34 | 35 | Returns: 36 | torch.Tensor: reshaped data with shape [B, L, N, C] 37 | """ 38 | 39 | # select feature using self.target_features 40 | data = data[:, :, :, self.target_features] 41 | return data 42 | 43 | def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> tuple: 44 | """Feed forward process for train, val, and test. Note that the outputs are NOT re-scaled. 45 | 46 | Args: 47 | data (tuple): data (future data, history ata). 48 | epoch (int, optional): epoch number. Defaults to None. 49 | iter_num (int, optional): iteration number. Defaults to None. 50 | train (bool, optional): if in the training process. Defaults to True. 
51 | 52 | Returns: 53 | tuple: (prediction, real_value) 54 | """ 55 | 56 | # preprocess 57 | future_data, history_data = data 58 | history_data = self.to_running_device(history_data) # B, L, N, C 59 | future_data = self.to_running_device(future_data) # B, L, N, C 60 | batch_size, length, num_nodes, _ = future_data.shape 61 | 62 | history_data = self.select_input_features(history_data) 63 | future_data_4_dec = self.select_input_features(future_data) 64 | 65 | # curriculum learning 66 | if self.cl_param is None: 67 | prediction_data = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train) 68 | else: 69 | task_level = self.curriculum_learning(epoch) 70 | prediction_data = self.model(history_data=history_data, future_data=future_data_4_dec, batch_seen=iter_num, epoch=epoch, train=train,\ 71 | task_level=task_level) 72 | # feed forward 73 | assert list(prediction_data.shape)[:3] == [batch_size, length, num_nodes], \ 74 | "error shape of the output, edit the forward function to reshape it to [B, L, N, C]" 75 | # post process 76 | prediction = self.select_target_features(prediction_data) 77 | real_value = self.select_target_features(future_data) 78 | return prediction, real_value 79 | -------------------------------------------------------------------------------- /basicts/runners/runner_zoo/stemgnn_runner.py: -------------------------------------------------------------------------------- 1 | from .simple_tsf_runner import SimpleTimeSeriesForecastingRunner as StemGNNRunner 2 | -------------------------------------------------------------------------------- /basicts/runners/runner_zoo/stgcn_runner.py: -------------------------------------------------------------------------------- 1 | from .simple_tsf_runner import SimpleTimeSeriesForecastingRunner as STGCNRunner 2 | -------------------------------------------------------------------------------- /basicts/runners/runner_zoo/stid_runner.py: 
def calculate_symmetric_normalized_laplacian(adj: np.ndarray) -> sp.spmatrix:
    """Calculate symmetric normalized laplacian.
    Assuming unnormalized laplacian matrix is `L = D - A`,
    then symmetric normalized laplacian matrix is:
    `L^{Sym} =  D^-1/2 L D^-1/2 =  D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2`
    For node `i` and `j` where `i!=j`, L^{sym}_{ij} <=0.

    Args:
        adj (np.ndarray): Adjacent matrix A

    Returns:
        scipy sparse matrix: Symmetric normalized laplacian L^{Sym}
    """

    adj = sp.coo_matrix(adj)
    degree = np.array(adj.sum(1))
    # diagonals of D^{-1/2}; a zero degree yields inf, which is expected and
    # zeroed right below, so the divide-by-zero warning is silenced.
    with np.errstate(divide="ignore"):
        degree_inv_sqrt = np.power(degree, -0.5).flatten()
    degree_inv_sqrt[np.isinf(degree_inv_sqrt)] = 0.
    matrix_degree_inv_sqrt = sp.diags(degree_inv_sqrt)   # D^{-1/2}
    symmetric_normalized_laplacian = sp.eye(
        adj.shape[0]) - matrix_degree_inv_sqrt.dot(adj).dot(matrix_degree_inv_sqrt).tocoo()
    return symmetric_normalized_laplacian


def calculate_scaled_laplacian(adj: np.ndarray, lambda_max: int = 2, undirected: bool = True) -> sp.spmatrix:
    """Re-scale the eigenvalues to [-1, 1] by scaling the normalized laplacian matrix for chebyshev pol.
    According to `2017 ICLR GCN`, the lambda max is set to 2, and the graph is set to undirected.
    Note that rescaling the laplacian matrix is equal to rescaling the eigenvalue matrix.
    `L_{scaled} = (2 / lambda_max * L) - I`

    Args:
        adj (np.ndarray): Adjacent matrix A
        lambda_max (int, optional): Defaults to 2. When None, the largest-magnitude
            eigenvalue of the laplacian is computed with `eigsh`.
        undirected (bool, optional): Defaults to True. When True, A is symmetrized
            with the element-wise maximum of A and A^T.

    Returns:
        scipy sparse matrix: The rescaled laplacian matrix.
    """

    if undirected:
        adj = np.maximum.reduce([adj, adj.T])
    laplacian_matrix = calculate_symmetric_normalized_laplacian(adj)
    if lambda_max is None:  # manually cal the max lambda
        lambda_max, _ = linalg.eigsh(laplacian_matrix, 1, which='LM')
        lambda_max = lambda_max[0]
    laplacian_matrix = sp.csr_matrix(laplacian_matrix)
    num_nodes, _ = laplacian_matrix.shape
    identity_matrix = sp.identity(
        num_nodes, format='csr', dtype=laplacian_matrix.dtype)
    laplacian_res = (2 / lambda_max * laplacian_matrix) - identity_matrix
    return laplacian_res
def calculate_transition_matrix(adj: np.ndarray) -> np.matrix:
    """Calculate the transition matrix `P` proposed in DCRNN and Graph WaveNet.
    P = D^{-1}A = A/rowsum(A)

    Rows that sum to zero are left as all-zero rows.

    Args:
        adj (np.ndarray): Adjacent matrix A

    Returns:
        np.matrix: Transition matrix P (float32, dense)
    """

    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1)).flatten()
    # NumPy refuses integer ** negative integer, so integer/bool adjacency
    # matrices need a float row sum; float inputs keep their own dtype.
    if not np.issubdtype(row_sum.dtype, np.floating):
        row_sum = row_sum.astype(np.float64)
    # zero rows give 1/0 = inf, which is expected and zeroed right below
    with np.errstate(divide="ignore"):
        d_inv = np.power(row_sum, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat = sp.diags(d_inv)
    prob_matrix = d_mat.dot(adj).astype(np.float32).todense()
    return prob_matrix
def remove_nan_inf(tensor: torch.Tensor):
    """Replace every NaN and +/-inf entry of a tensor with zero.

    Args:
        tensor (torch.Tensor): input tensor

    Returns:
        torch.Tensor: tensor of the same shape with non-finite entries set to 0
    """

    zeros = torch.zeros_like(tensor)
    invalid = torch.isnan(tensor) | torch.isinf(tensor)
    return torch.where(invalid, zeros, tensor)
def load_adj(file_path: str, adj_type: str):
    """Load an adjacency matrix from a pickle file and pre-process it.

    The pickle holds either a 3-item tuple/list whose last item is the
    adjacency matrix (METR and PEMS_BAY) or the bare matrix itself (PEMS04).

    Args:
        file_path (str): file path
        adj_type (str): adjacency matrix type, one of "scalap", "normlap",
            "symnadj", "transition", "doubletransition", "identity", "original"

    Returns:
        list: list of preprocessed adjacency matrices
        np.ndarray: raw adjacency matrix

    Raises:
        ValueError: if `adj_type` is not one of the supported types.
    """

    # Inspect the container instead of relying on a failed 3-way unpack: the
    # unpack also "succeeds" on a bare matrix with exactly three rows, and the
    # fallback had to read the file a second time.
    pickle_data = load_pkl(file_path)
    if isinstance(pickle_data, (tuple, list)) and len(pickle_data) == 3:
        # METR and PEMS_BAY
        adj_mx = pickle_data[2]
    else:
        # PEMS04
        adj_mx = pickle_data

    if adj_type == "scalap":
        adj = [calculate_scaled_laplacian(adj_mx).astype(np.float32).todense()]
    elif adj_type == "normlap":
        adj = [calculate_symmetric_normalized_laplacian(
            adj_mx).astype(np.float32).todense()]
    elif adj_type == "symnadj":
        adj = [calculate_symmetric_message_passing_adj(
            adj_mx).astype(np.float32).todense()]
    elif adj_type == "transition":
        adj = [calculate_transition_matrix(adj_mx).T]
    elif adj_type == "doubletransition":
        # forward and backward random-walk supports
        adj = [calculate_transition_matrix(adj_mx).T,
               calculate_transition_matrix(adj_mx.T).T]
    elif adj_type == "identity":
        adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)]
    elif adj_type == "original":
        adj = [adj_mx]
    else:
        # An explicit raise survives `python -O`, where the former
        # `assert 0, ...` was stripped and `adj` ended up unbound.
        raise ValueError("adj type not defined")
    return adj, adj_mx
int(temp[0]) 102 | spatial_embeddings[index] = torch.Tensor([float(ch) for ch in temp[1:]]) 103 | return spatial_embeddings 104 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Data Preparation 3 | 4 | ## Download Raw Data 5 | 6 | You can download all the raw datasets at [Google Drive](https://drive.google.com/file/d/1PY7IZ3SchpyXfNIXs71A2GEV29W5QCv2/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1CXLxeHxHIMWLy3IKGFUq8g?pwd=blf8), and unzip them to `datasets/raw_data/`. 7 | 8 | ## Pre-process Data 9 | 10 | You can pre-process all data via: 11 | 12 | ```bash 13 | cd /path/to/your/project 14 | bash scripts/data_preparation/all.sh 15 | ``` 16 | 17 | Then the `datasets` directory will look like this: 18 | 19 | ```text 20 | datasets 21 | ├─METR-LA 22 | ├─PEMS-BAY 23 | ├─PEMS04 24 | ├─raw_data 25 | | ├─PEMS04 26 | | ├─PEMS-BAY 27 | | ├─METR-LA 28 | ├─README.md 29 | ``` 30 | -------------------------------------------------------------------------------- /figure/Inspecting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GestaltCogTeam/STEP/566e2738da2d83f055718d8edb609ad8dc325204/figure/Inspecting.jpg -------------------------------------------------------------------------------- /figure/MainResults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GestaltCogTeam/STEP/566e2738da2d83f055718d8edb609ad8dc325204/figure/MainResults.png -------------------------------------------------------------------------------- /figure/STEP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GestaltCogTeam/STEP/566e2738da2d83f055718d8edb609ad8dc325204/figure/STEP.png
def get_adjacency_matrix(distance_df_filename: str, num_of_vertices: int, id_filename: str = None) -> tuple:
    """Generate adjacency matrices from an edge-list file.

    The csv file has one header line followed by `from,to,distance` rows; rows
    that do not have exactly three fields are ignored. Every edge is written in
    both directions, so both matrices are symmetric.

    Args:
        distance_df_filename (str): path of the csv file contains edges information,
            or of a `.npy` file holding a ready-made adjacency matrix
        num_of_vertices (int): number of vertices
        id_filename (str, optional): file with one node id per line; the line
            number becomes the 0-based index of that id. Defaults to None,
            meaning the ids in the csv are already 0-based indices.

    Returns:
        tuple: two adjacency matrix.
            np.array: connectivity-based adjacency matrix A (A[i, j]=0 or A[i, j]=1)
            np.array: distance-based adjacency matrix A
            (for a `.npy` input: the loaded matrix and None)
    """

    # Test the extension, not a substring: a csv stored under e.g. "npy_data/"
    # must still be parsed as csv.
    if distance_df_filename.endswith(".npy"):
        return np.load(distance_df_filename), None

    num_of_vertices = int(num_of_vertices)
    adjacency_matrix_connectivity = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)
    adjacency_matrix_distance = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)

    id_dict = None
    if id_filename:
        # the id in the distance file does not start from 0, so it needs to be remapped
        with open(id_filename, "r") as f:
            id_dict = {int(node_id): idx for idx, node_id in enumerate(
                f.read().strip().split("\n"))}

    with open(distance_df_filename, "r") as f:
        f.readline()  # omit the header line
        for row in csv.reader(f):
            if len(row) != 3:
                continue
            i, j, distance = int(row[0]), int(row[1]), float(row[2])
            if id_dict is not None:
                i, j = id_dict[i], id_dict[j]
            adjacency_matrix_connectivity[i, j] = 1
            adjacency_matrix_connectivity[j, i] = 1
            adjacency_matrix_distance[i, j] = distance
            adjacency_matrix_distance[j, i] = distance
    return adjacency_matrix_connectivity, adjacency_matrix_distance
def get_adjacency_matrix(distance_df_filename: str, num_of_vertices: int, id_filename: str = None) -> tuple:
    """Generate adjacency matrices from an edge-list file.

    The csv file has one header line followed by `from,to,distance` rows; rows
    that do not have exactly three fields are ignored. Every edge is written in
    both directions, so both matrices are symmetric.

    Args:
        distance_df_filename (str): path of the csv file contains edges information,
            or of a `.npy` file holding a ready-made adjacency matrix
        num_of_vertices (int): number of vertices
        id_filename (str, optional): file with one node id per line; the line
            number becomes the 0-based index of that id. Defaults to None,
            meaning the ids in the csv are already 0-based indices.

    Returns:
        tuple: two adjacency matrix.
            np.array: connectivity-based adjacency matrix A (A[i, j]=0 or A[i, j]=1)
            np.array: distance-based adjacency matrix A
            (for a `.npy` input: the loaded matrix and None)
    """

    # Test the extension, not a substring: a csv stored under e.g. "npy_data/"
    # must still be parsed as csv.
    if distance_df_filename.endswith(".npy"):
        return np.load(distance_df_filename), None

    num_of_vertices = int(num_of_vertices)
    adjacency_matrix_connectivity = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)
    adjacency_matrix_distance = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)

    id_dict = None
    if id_filename:
        # the id in the distance file does not start from 0, so it needs to be remapped
        with open(id_filename, "r") as f:
            id_dict = {int(node_id): idx for idx, node_id in enumerate(
                f.read().strip().split("\n"))}

    with open(distance_df_filename, "r") as f:
        f.readline()  # omit the header line
        for row in csv.reader(f):
            if len(row) != 3:
                continue
            i, j, distance = int(row[0]), int(row[1]), float(row[2])
            if id_dict is not None:
                i, j = id_dict[i], id_dict[j]
            adjacency_matrix_connectivity[i, j] = 1
            adjacency_matrix_connectivity[j, i] = 1
            adjacency_matrix_distance[i, j] = distance
            adjacency_matrix_distance[j, i] = distance
    return adjacency_matrix_connectivity, adjacency_matrix_distance
def get_adjacency_matrix(distance_df_filename: str, num_of_vertices: int, id_filename: str = None) -> tuple:
    """Generate adjacency matrices from an edge-list file.

    The csv file has one header line followed by `from,to,distance` rows; rows
    that do not have exactly three fields are ignored. Every edge is written in
    both directions, so both matrices are symmetric.

    Args:
        distance_df_filename (str): path of the csv file contains edges information,
            or of a `.npy` file holding a ready-made adjacency matrix
        num_of_vertices (int): number of vertices
        id_filename (str, optional): file with one node id per line; the line
            number becomes the 0-based index of that id. Defaults to None,
            meaning the ids in the csv are already 0-based indices.

    Returns:
        tuple: two adjacency matrix.
            np.array: connectivity-based adjacency matrix A (A[i, j]=0 or A[i, j]=1)
            np.array: distance-based adjacency matrix A
            (for a `.npy` input: the loaded matrix and None)
    """

    # Test the extension, not a substring: a csv stored under e.g. "npy_data/"
    # must still be parsed as csv.
    if distance_df_filename.endswith(".npy"):
        return np.load(distance_df_filename), None

    num_of_vertices = int(num_of_vertices)
    adjacency_matrix_connectivity = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)
    adjacency_matrix_distance = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)

    id_dict = None
    if id_filename:
        # the id in the distance file does not start from 0, so it needs to be remapped
        with open(id_filename, "r") as f:
            id_dict = {int(node_id): idx for idx, node_id in enumerate(
                f.read().strip().split("\n"))}

    with open(distance_df_filename, "r") as f:
        f.readline()  # omit the header line
        for row in csv.reader(f):
            if len(row) != 3:
                continue
            i, j, distance = int(row[0]), int(row[1]), float(row[2])
            if id_dict is not None:
                i, j = id_dict[i], id_dict[j]
            adjacency_matrix_connectivity[i, j] = 1
            adjacency_matrix_connectivity[j, i] = 1
            adjacency_matrix_distance[i, j] = distance
            adjacency_matrix_distance[j, i] = distance
    return adjacency_matrix_connectivity, adjacency_matrix_distance
def get_adjacency_matrix(distance_df_filename: str, num_of_vertices: int, id_filename: str = None) -> tuple:
    """Generate adjacency matrices from an edge-list file.

    The csv file has one header line followed by `from,to,distance` rows; rows
    that do not have exactly three fields are ignored. Every edge is written in
    both directions, so both matrices are symmetric.

    Args:
        distance_df_filename (str): path of the csv file contains edges information,
            or of a `.npy` file holding a ready-made adjacency matrix
        num_of_vertices (int): number of vertices
        id_filename (str, optional): file with one node id per line; the line
            number becomes the 0-based index of that id. Defaults to None,
            meaning the ids in the csv are already 0-based indices.

    Returns:
        tuple: two adjacency matrix.
            np.array: connectivity-based adjacency matrix A (A[i, j]=0 or A[i, j]=1)
            np.array: distance-based adjacency matrix A
            (for a `.npy` input: the loaded matrix and None)
    """

    # Test the extension, not a substring: a csv stored under e.g. "npy_data/"
    # must still be parsed as csv.
    if distance_df_filename.endswith(".npy"):
        return np.load(distance_df_filename), None

    num_of_vertices = int(num_of_vertices)
    adjacency_matrix_connectivity = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)
    adjacency_matrix_distance = np.zeros(
        (num_of_vertices, num_of_vertices), dtype=np.float32)

    id_dict = None
    if id_filename:
        # the id in the distance file does not start from 0, so it needs to be remapped
        with open(id_filename, "r") as f:
            id_dict = {int(node_id): idx for idx, node_id in enumerate(
                f.read().strip().split("\n"))}

    with open(distance_df_filename, "r") as f:
        f.readline()  # omit the header line
        for row in csv.reader(f):
            if len(row) != 3:
                continue
            i, j, distance = int(row[0]), int(row[1]), float(row[2])
            if id_dict is not None:
                i, j = id_dict[i], id_dict[j]
            adjacency_matrix_connectivity[i, j] = 1
            adjacency_matrix_connectivity[j, i] = 1
            adjacency_matrix_distance[i, j] = distance
            adjacency_matrix_distance[j, i] = distance
    return adjacency_matrix_connectivity, adjacency_matrix_distance
# Configuration of the STEP model on the PEMS-BAY dataset.
# The module only builds the `CFG` EasyDict; nothing is trained at import time.
import os
import sys


# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../../.."))
import torch
from easydict import EasyDict
from basicts.utils.serialization import load_adj

from .step_arch import STEP
from .step_runner import STEPRunner
from .step_loss import step_loss
from .step_data import ForecastingDataset


CFG = EasyDict()

# ================= general ================= #
CFG.DESCRIPTION = "STEP(PEMS-BAY) configuration"
CFG.RUNNER = STEPRunner
CFG.DATASET_CLS = ForecastingDataset
CFG.DATASET_NAME = "PEMS-BAY"
CFG.DATASET_TYPE = "Traffic speed"
CFG.DATASET_INPUT_LEN = 12
CFG.DATASET_OUTPUT_LEN = 12
CFG.DATASET_ARGS = {
    "seq_len": 288 * 7
    }
CFG.GPU_NUM = 2

# ================= environment ================= #
CFG.ENV = EasyDict()
CFG.ENV.SEED = 0
CFG.ENV.CUDNN = EasyDict()
CFG.ENV.CUDNN.ENABLED = True

# ================= model ================= #
CFG.MODEL = EasyDict()
CFG.MODEL.NAME = "STEP"
CFG.MODEL.ARCH = STEP
CFG.MODEL.PARAM = {
    "dataset_name": CFG.DATASET_NAME,
    "pre_trained_tsformer_path": "tsformer_ckpt/TSFormer_PEMS-BAY.pt",
    "tsformer_args": {
                    "patch_size":12,
                    "in_channel":1,
                    "embed_dim":96,
                    "num_heads":4,
                    "mlp_ratio":4,
                    "dropout":0.1,
                    # seq_len / patch_size; floor division keeps the token
                    # count an int (true division produced the float 168.0)
                    "num_token":288 * 7 // 12,
                    "mask_ratio":0.75,
                    "encoder_depth":4,
                    "decoder_depth":1,
                    "mode":"forecasting"
    },
    "backend_args": {
                    "num_nodes" : 325,
                    "support_len" : 2,
                    "dropout" : 0.3,
                    "gcn_bool" : True,
                    "addaptadj" : True,
                    "aptinit" : None,
                    "in_dim" : 2,
                    "out_dim" : 12,
                    "residual_channels" : 32,
                    "dilation_channels" : 32,
                    "skip_channels" : 256,
                    "end_channels" : 512,
                    "kernel_size" : 2,
                    "blocks" : 4,
                    "layers" : 2
    },
    "dgl_args": {
                "dataset_name": CFG.DATASET_NAME,
                "k": 10,
                "input_seq_len": CFG.DATASET_INPUT_LEN,
                "output_seq_len": CFG.DATASET_OUTPUT_LEN
    }
}
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2]
CFG.MODEL.TARGET_FEATURES = [0]
CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = True

# ================= optim ================= #
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = step_loss
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM= {
    "lr":0.001,
    "weight_decay":1.0e-5,
    "eps":1.0e-8,
}
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM= {
    "milestones":[1, 18, 36, 54, 72],
    "gamma":0.5
}

# ================= train ================= #
CFG.TRAIN.CLIP_GRAD_PARAM = {
    "max_norm": 3.0
}
CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    "checkpoints",
    "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
# train data
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.NULL_VAL = 0.0
# read data
CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TRAIN.DATA.BATCH_SIZE = 32
CFG.TRAIN.DATA.PREFETCH = False
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.NUM_WORKERS = 2
CFG.TRAIN.DATA.PIN_MEMORY = True
# curriculum learning
CFG.TRAIN.CL = EasyDict()
CFG.TRAIN.CL.WARM_EPOCHS = 30
CFG.TRAIN.CL.CL_EPOCHS = 3
CFG.TRAIN.CL.PREDICTION_LENGTH = 12

# ================= validate ================= #
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
# validating data
CFG.VAL.DATA = EasyDict()
# read data
CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.VAL.DATA.BATCH_SIZE = 32
CFG.VAL.DATA.PREFETCH = False
CFG.VAL.DATA.SHUFFLE = False
CFG.VAL.DATA.NUM_WORKERS = 2
CFG.VAL.DATA.PIN_MEMORY = True

# ================= test ================= #
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
# evaluation
# test data
CFG.TEST.DATA = EasyDict()
# read data
CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TEST.DATA.BATCH_SIZE = 32
CFG.TEST.DATA.PREFETCH = False
CFG.TEST.DATA.SHUFFLE = False
CFG.TEST.DATA.NUM_WORKERS = 2
CFG.TEST.DATA.PIN_MEMORY = True
"mlp_ratio":4, 51 | "dropout":0.1, 52 | "num_token":288 * 7 * 2 / 12, 53 | "mask_ratio":0.75, 54 | "encoder_depth":4, 55 | "decoder_depth":1, 56 | "mode":"forecasting" 57 | }, 58 | "backend_args": { 59 | "num_nodes" : 358, 60 | "support_len" : 2, 61 | "dropout" : 0.3, 62 | "gcn_bool" : True, 63 | "addaptadj" : True, 64 | "aptinit" : None, 65 | "in_dim" : 2, 66 | "out_dim" : 12, 67 | "residual_channels" : 32, 68 | "dilation_channels" : 32, 69 | "skip_channels" : 256, 70 | "end_channels" : 512, 71 | "kernel_size" : 2, 72 | "blocks" : 4, 73 | "layers" : 2 74 | }, 75 | "dgl_args": { 76 | "dataset_name": CFG.DATASET_NAME, 77 | "k": 10, 78 | "input_seq_len": CFG.DATASET_INPUT_LEN, 79 | "output_seq_len": CFG.DATASET_OUTPUT_LEN 80 | } 81 | } 82 | CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] 83 | CFG.MODEL.TARGET_FEATURES = [0] 84 | CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = True 85 | 86 | # ================= optim ================= # 87 | CFG.TRAIN = EasyDict() 88 | CFG.TRAIN.LOSS = step_loss 89 | CFG.TRAIN.OPTIM = EasyDict() 90 | CFG.TRAIN.OPTIM.TYPE = "Adam" 91 | CFG.TRAIN.OPTIM.PARAM= { 92 | "lr":0.002, 93 | "weight_decay":1.0e-5, 94 | "eps":1.0e-8, 95 | } 96 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 97 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 98 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 99 | "milestones":[1, 18, 36, 54, 72], 100 | "gamma":0.5 101 | } 102 | 103 | # ================= train ================= # 104 | CFG.TRAIN.CLIP_GRAD_PARAM = { 105 | "max_norm": 3.0 106 | } 107 | CFG.TRAIN.NUM_EPOCHS = 100 108 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 109 | "checkpoints", 110 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 111 | ) 112 | # train data 113 | CFG.TRAIN.DATA = EasyDict() 114 | CFG.TRAIN.NULL_VAL = 0.0 115 | # read data 116 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 117 | # dataloader args, optional 118 | CFG.TRAIN.DATA.BATCH_SIZE = 4 119 | CFG.TRAIN.DATA.PREFETCH = False 120 | CFG.TRAIN.DATA.SHUFFLE = True 121 | CFG.TRAIN.DATA.NUM_WORKERS = 2 122 | 
# Configuration of the STEP model on the PEMS04 dataset.
# The module only builds the `CFG` EasyDict; nothing is trained at import time.
import os
import sys


# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../../.."))
import torch
from easydict import EasyDict
from basicts.utils.serialization import load_adj

from .step_arch import STEP
from .step_runner import STEPRunner
from .step_loss import step_loss
from .step_data import ForecastingDataset


CFG = EasyDict()

# ================= general ================= #
CFG.DESCRIPTION = "STEP(PEMS04) configuration"
CFG.RUNNER = STEPRunner
CFG.DATASET_CLS = ForecastingDataset
CFG.DATASET_NAME = "PEMS04"
CFG.DATASET_TYPE = "Traffic flow"
CFG.DATASET_INPUT_LEN = 12
CFG.DATASET_OUTPUT_LEN = 12
CFG.DATASET_ARGS = {
    "seq_len": 288 * 7 * 2
    }
CFG.GPU_NUM = 2

# ================= environment ================= #
CFG.ENV = EasyDict()
CFG.ENV.SEED = 0
CFG.ENV.CUDNN = EasyDict()
CFG.ENV.CUDNN.ENABLED = True

# ================= model ================= #
CFG.MODEL = EasyDict()
CFG.MODEL.NAME = "STEP"
CFG.MODEL.ARCH = STEP
CFG.MODEL.PARAM = {
    "dataset_name": CFG.DATASET_NAME,
    "pre_trained_tsformer_path": "tsformer_ckpt/TSFormer_PEMS04.pt",
    "tsformer_args": {
                    "patch_size":12,
                    "in_channel":1,
                    "embed_dim":96,
                    "num_heads":4,
                    "mlp_ratio":4,
                    "dropout":0.1,
                    # seq_len / patch_size; floor division keeps the token
                    # count an int (true division produced the float 336.0)
                    "num_token":288 * 7 * 2 // 12,
                    "mask_ratio":0.75,
                    "encoder_depth":4,
                    "decoder_depth":1,
                    "mode":"forecasting"
    },
    "backend_args": {
                    "num_nodes" : 307,
                    "support_len" : 2,
                    "dropout" : 0.3,
                    "gcn_bool" : True,
                    "addaptadj" : True,
                    "aptinit" : None,
                    "in_dim" : 2,
                    "out_dim" : 12,
                    "residual_channels" : 32,
                    "dilation_channels" : 32,
                    "skip_channels" : 256,
                    "end_channels" : 512,
                    "kernel_size" : 2,
                    "blocks" : 4,
                    "layers" : 2
    },
    "dgl_args": {
                "dataset_name": CFG.DATASET_NAME,
                "k": 10,
                "input_seq_len": CFG.DATASET_INPUT_LEN,
                "output_seq_len": CFG.DATASET_OUTPUT_LEN
    }
}
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2]
CFG.MODEL.TARGET_FEATURES = [0]
CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = True

# ================= optim ================= #
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = step_loss
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM= {
    "lr":0.002,
    "weight_decay":1.0e-5,
    "eps":1.0e-8,
}
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM= {
    "milestones":[1, 18, 36, 54, 72],
    "gamma":0.5
}

# ================= train ================= #
CFG.TRAIN.CLIP_GRAD_PARAM = {
    "max_norm": 3.0
}
CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    "checkpoints",
    "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
# train data
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.NULL_VAL = 0.0
# read data
CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TRAIN.DATA.BATCH_SIZE = 8
CFG.TRAIN.DATA.PREFETCH = False
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.NUM_WORKERS = 2
CFG.TRAIN.DATA.PIN_MEMORY = True

# ================= validate ================= #
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
# validating data
CFG.VAL.DATA = EasyDict()
# read data
CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.VAL.DATA.BATCH_SIZE = 8
CFG.VAL.DATA.PREFETCH = False
CFG.VAL.DATA.SHUFFLE = False
CFG.VAL.DATA.NUM_WORKERS = 2
CFG.VAL.DATA.PIN_MEMORY = True

# ================= test ================= #
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
# evaluation
# test data
CFG.TEST.DATA = EasyDict()
# read data
CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TEST.DATA.BATCH_SIZE = 8
CFG.TEST.DATA.PREFETCH = False
CFG.TEST.DATA.SHUFFLE = False
CFG.TEST.DATA.NUM_WORKERS = 2
CFG.TEST.DATA.PIN_MEMORY = True
.step_arch import STEP 11 | from .step_runner import STEPRunner 12 | from .step_loss import step_loss 13 | from .step_data import ForecastingDataset 14 | 15 | 16 | CFG = EasyDict() 17 | 18 | # ================= general ================= # 19 | CFG.DESCRIPTION = "STEP(PEMS07) configuration" 20 | CFG.RUNNER = STEPRunner 21 | CFG.DATASET_CLS = ForecastingDataset 22 | CFG.DATASET_NAME = "PEMS07" 23 | CFG.DATASET_TYPE = "Traffic flow" 24 | CFG.DATASET_INPUT_LEN = 12 25 | CFG.DATASET_OUTPUT_LEN = 12 26 | CFG.DATASET_ARGS = { 27 | "seq_len": 288 * 7 28 | } 29 | CFG.GPU_NUM = 2 30 | 31 | # ================= environment ================= # 32 | CFG.ENV = EasyDict() 33 | CFG.ENV.SEED = 0 34 | CFG.ENV.CUDNN = EasyDict() 35 | CFG.ENV.CUDNN.ENABLED = True 36 | 37 | # ================= model ================= # 38 | CFG.MODEL = EasyDict() 39 | CFG.MODEL.NAME = "STEP" 40 | CFG.MODEL.ARCH = STEP 41 | CFG.MODEL.PARAM = { 42 | "dataset_name": CFG.DATASET_NAME, 43 | "pre_trained_tsformer_path": "tsformer_ckpt/TSFormer_PEMS07.pt", 44 | "tsformer_args": { 45 | "patch_size":12, 46 | "in_channel":1, 47 | "embed_dim":96, 48 | "num_heads":4, 49 | "mlp_ratio":4, 50 | "dropout":0.1, 51 | "num_token":288 * 7 / 12, 52 | "mask_ratio":0.75, 53 | "encoder_depth":4, 54 | "decoder_depth":1, 55 | "mode":"forecasting" 56 | }, 57 | "backend_args": { 58 | "num_nodes" : 883, 59 | "support_len" : 2, 60 | "dropout" : 0.3, 61 | "gcn_bool" : True, 62 | "addaptadj" : True, 63 | "aptinit" : None, 64 | "in_dim" : 2, 65 | "out_dim" : 12, 66 | "residual_channels" : 32, 67 | "dilation_channels" : 32, 68 | "skip_channels" : 256, 69 | "end_channels" : 512, 70 | "kernel_size" : 2, 71 | "blocks" : 4, 72 | "layers" : 2 73 | }, 74 | "dgl_args": { 75 | "dataset_name": CFG.DATASET_NAME, 76 | "k": 10, 77 | "input_seq_len": CFG.DATASET_INPUT_LEN, 78 | "output_seq_len": CFG.DATASET_OUTPUT_LEN 79 | } 80 | } 81 | CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] 82 | CFG.MODEL.TARGET_FEATURES = [0] 83 | 
CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = True 84 | 85 | # ================= optim ================= # 86 | CFG.TRAIN = EasyDict() 87 | CFG.TRAIN.LOSS = step_loss 88 | CFG.TRAIN.OPTIM = EasyDict() 89 | CFG.TRAIN.OPTIM.TYPE = "Adam" 90 | CFG.TRAIN.OPTIM.PARAM= { 91 | "lr":0.002, 92 | "weight_decay":1.0e-5, 93 | "eps":1.0e-8, 94 | } 95 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 96 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 97 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 98 | "milestones":[1, 18, 36, 54, 72], 99 | "gamma":0.5 100 | } 101 | 102 | # ================= train ================= # 103 | CFG.TRAIN.CLIP_GRAD_PARAM = { 104 | "max_norm": 3.0 105 | } 106 | CFG.TRAIN.NUM_EPOCHS = 100 107 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 108 | "checkpoints", 109 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 110 | ) 111 | # train data 112 | CFG.TRAIN.DATA = EasyDict() 113 | CFG.TRAIN.NULL_VAL = 0.0 114 | # read data 115 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 116 | # dataloader args, optional 117 | CFG.TRAIN.DATA.BATCH_SIZE = 4 118 | CFG.TRAIN.DATA.PREFETCH = False 119 | CFG.TRAIN.DATA.SHUFFLE = True 120 | CFG.TRAIN.DATA.NUM_WORKERS = 2 121 | CFG.TRAIN.DATA.PIN_MEMORY = True 122 | 123 | # ================= validate ================= # 124 | CFG.VAL = EasyDict() 125 | CFG.VAL.INTERVAL = 1 126 | # validating data 127 | CFG.VAL.DATA = EasyDict() 128 | # read data 129 | CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME 130 | # dataloader args, optional 131 | CFG.VAL.DATA.BATCH_SIZE = 4 132 | CFG.VAL.DATA.PREFETCH = False 133 | CFG.VAL.DATA.SHUFFLE = False 134 | CFG.VAL.DATA.NUM_WORKERS = 2 135 | CFG.VAL.DATA.PIN_MEMORY = True 136 | 137 | # ================= test ================= # 138 | CFG.TEST = EasyDict() 139 | CFG.TEST.INTERVAL = 1 140 | # evluation 141 | # test data 142 | CFG.TEST.DATA = EasyDict() 143 | # read data 144 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 145 | # dataloader args, optional 146 | CFG.TEST.DATA.BATCH_SIZE = 4 147 | 
CFG.TEST.DATA.PREFETCH = False 148 | CFG.TEST.DATA.SHUFFLE = False 149 | CFG.TEST.DATA.NUM_WORKERS = 2 150 | CFG.TEST.DATA.PIN_MEMORY = True 151 | -------------------------------------------------------------------------------- /step/STEP_PEMS08.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | # TODO: remove it when basicts can be installed by pip 6 | sys.path.append(os.path.abspath(__file__ + "/../../..")) 7 | import torch 8 | from easydict import EasyDict 9 | from basicts.utils.serialization import load_adj 10 | 11 | from .step_arch import STEP 12 | from .step_runner import STEPRunner 13 | from .step_loss import step_loss 14 | from .step_data import ForecastingDataset 15 | 16 | 17 | CFG = EasyDict() 18 | 19 | # ================= general ================= # 20 | CFG.DESCRIPTION = "STEP(PEMS08) configuration" 21 | CFG.RUNNER = STEPRunner 22 | CFG.DATASET_CLS = ForecastingDataset 23 | CFG.DATASET_NAME = "PEMS08" 24 | CFG.DATASET_TYPE = "Traffic flow" 25 | CFG.DATASET_INPUT_LEN = 12 26 | CFG.DATASET_OUTPUT_LEN = 12 27 | CFG.DATASET_ARGS = { 28 | "seq_len": 288 * 7 * 2 29 | } 30 | CFG.GPU_NUM = 2 31 | 32 | # ================= environment ================= # 33 | CFG.ENV = EasyDict() 34 | CFG.ENV.SEED = 0 35 | CFG.ENV.CUDNN = EasyDict() 36 | CFG.ENV.CUDNN.ENABLED = True 37 | 38 | # ================= model ================= # 39 | CFG.MODEL = EasyDict() 40 | CFG.MODEL.NAME = "STEP" 41 | CFG.MODEL.ARCH = STEP 42 | CFG.MODEL.PARAM = { 43 | "dataset_name": CFG.DATASET_NAME, 44 | "pre_trained_tsformer_path": "tsformer_ckpt/TSFormer_PEMS08.pt", 45 | "tsformer_args": { 46 | "patch_size":12, 47 | "in_channel":1, 48 | "embed_dim":96, 49 | "num_heads":4, 50 | "mlp_ratio":4, 51 | "dropout":0.1, 52 | "num_token":288 * 7 * 2 / 12, 53 | "mask_ratio":0.75, 54 | "encoder_depth":4, 55 | "decoder_depth":1, 56 | "mode":"forecasting" 57 | }, 58 | "backend_args": { 59 | "num_nodes" : 170, 60 | "support_len" : 2, 
61 | "dropout" : 0.3, 62 | "gcn_bool" : True, 63 | "addaptadj" : True, 64 | "aptinit" : None, 65 | "in_dim" : 2, 66 | "out_dim" : 12, 67 | "residual_channels" : 32, 68 | "dilation_channels" : 32, 69 | "skip_channels" : 256, 70 | "end_channels" : 512, 71 | "kernel_size" : 2, 72 | "blocks" : 4, 73 | "layers" : 2 74 | }, 75 | "dgl_args": { 76 | "dataset_name": CFG.DATASET_NAME, 77 | "k": 10, 78 | "input_seq_len": CFG.DATASET_INPUT_LEN, 79 | "output_seq_len": CFG.DATASET_OUTPUT_LEN 80 | } 81 | } 82 | CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] 83 | CFG.MODEL.TARGET_FEATURES = [0] 84 | CFG.MODEL.DDP_FIND_UNUSED_PARAMETERS = True 85 | 86 | # ================= optim ================= # 87 | CFG.TRAIN = EasyDict() 88 | CFG.TRAIN.LOSS = step_loss 89 | CFG.TRAIN.OPTIM = EasyDict() 90 | CFG.TRAIN.OPTIM.TYPE = "Adam" 91 | CFG.TRAIN.OPTIM.PARAM= { 92 | "lr":0.002, 93 | "weight_decay":1.0e-5, 94 | "eps":1.0e-8, 95 | } 96 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 97 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 98 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 99 | "milestones":[1, 18, 36, 54, 72], 100 | "gamma":0.5 101 | } 102 | 103 | # ================= train ================= # 104 | CFG.TRAIN.CLIP_GRAD_PARAM = { 105 | "max_norm": 3.0 106 | } 107 | CFG.TRAIN.NUM_EPOCHS = 100 108 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 109 | "checkpoints", 110 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 111 | ) 112 | # train data 113 | CFG.TRAIN.DATA = EasyDict() 114 | CFG.TRAIN.NULL_VAL = 0.0 115 | # read data 116 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 117 | # dataloader args, optional 118 | CFG.TRAIN.DATA.BATCH_SIZE = 8 119 | CFG.TRAIN.DATA.PREFETCH = False 120 | CFG.TRAIN.DATA.SHUFFLE = True 121 | CFG.TRAIN.DATA.NUM_WORKERS = 2 122 | CFG.TRAIN.DATA.PIN_MEMORY = True 123 | 124 | # ================= validate ================= # 125 | CFG.VAL = EasyDict() 126 | CFG.VAL.INTERVAL = 1 127 | # validating data 128 | CFG.VAL.DATA = EasyDict() 129 | # read data 130 | CFG.VAL.DATA.DIR = 
"datasets/" + CFG.DATASET_NAME 131 | # dataloader args, optional 132 | CFG.VAL.DATA.BATCH_SIZE = 8 133 | CFG.VAL.DATA.PREFETCH = False 134 | CFG.VAL.DATA.SHUFFLE = False 135 | CFG.VAL.DATA.NUM_WORKERS = 2 136 | CFG.VAL.DATA.PIN_MEMORY = True 137 | 138 | # ================= test ================= # 139 | CFG.TEST = EasyDict() 140 | CFG.TEST.INTERVAL = 1 141 | # evluation 142 | # test data 143 | CFG.TEST.DATA = EasyDict() 144 | # read data 145 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 146 | # dataloader args, optional 147 | CFG.TEST.DATA.BATCH_SIZE = 8 148 | CFG.TEST.DATA.PREFETCH = False 149 | CFG.TEST.DATA.SHUFFLE = False 150 | CFG.TEST.DATA.NUM_WORKERS = 2 151 | CFG.TEST.DATA.PIN_MEMORY = True 152 | -------------------------------------------------------------------------------- /step/TSFormer_METR-LA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # TODO: remove it when basicts can be installed by pip 5 | sys.path.append(os.path.abspath(__file__ + "/../../..")) 6 | from easydict import EasyDict 7 | from basicts.losses import masked_mae 8 | 9 | from .step_arch import TSFormer 10 | from .step_runner import TSFormerRunner 11 | from .step_data import PretrainingDataset 12 | 13 | 14 | CFG = EasyDict() 15 | 16 | # ================= general ================= # 17 | CFG.DESCRIPTION = "TSFormer(METR-LA) configuration" 18 | CFG.RUNNER = TSFormerRunner 19 | CFG.DATASET_CLS = PretrainingDataset 20 | CFG.DATASET_NAME = "METR-LA" 21 | CFG.DATASET_TYPE = "Traffic speed" 22 | CFG.DATASET_INPUT_LEN = 288 * 7 23 | CFG.DATASET_OUTPUT_LEN = 12 24 | CFG.GPU_NUM = 1 25 | 26 | # ================= environment ================= # 27 | CFG.ENV = EasyDict() 28 | CFG.ENV.SEED = 0 29 | CFG.ENV.CUDNN = EasyDict() 30 | CFG.ENV.CUDNN.ENABLED = True 31 | 32 | # ================= model ================= # 33 | CFG.MODEL = EasyDict() 34 | CFG.MODEL.NAME = "TSFormer" 35 | CFG.MODEL.ARCH = TSFormer 36 | 
CFG.MODEL.PARAM = { 37 | "patch_size":12, 38 | "in_channel":1, 39 | "embed_dim":96, 40 | "num_heads":4, 41 | "mlp_ratio":4, 42 | "dropout":0.1, 43 | "num_token":288 * 7 / 12, 44 | "mask_ratio":0.75, 45 | "encoder_depth":4, 46 | "decoder_depth":1, 47 | "mode":"pre-train" 48 | } 49 | CFG.MODEL.FORWARD_FEATURES = [0] 50 | CFG.MODEL.TARGET_FEATURES = [0] 51 | 52 | # ================= optim ================= # 53 | CFG.TRAIN = EasyDict() 54 | CFG.TRAIN.LOSS = masked_mae 55 | CFG.TRAIN.OPTIM = EasyDict() 56 | CFG.TRAIN.OPTIM.TYPE = "Adam" 57 | CFG.TRAIN.OPTIM.PARAM= { 58 | "lr":0.0005, 59 | "weight_decay":0, 60 | "eps":1.0e-8, 61 | "betas":(0.9, 0.95) 62 | } 63 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 64 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 65 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 66 | "milestones":[50], 67 | "gamma":0.5 68 | } 69 | 70 | # ================= train ================= # 71 | CFG.TRAIN.CLIP_GRAD_PARAM = { 72 | "max_norm": 5.0 73 | } 74 | CFG.TRAIN.NUM_EPOCHS = 100 75 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 76 | "checkpoints", 77 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 78 | ) 79 | # train data 80 | CFG.TRAIN.DATA = EasyDict() 81 | CFG.TRAIN.NULL_VAL = 0.0 82 | # read data 83 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 84 | # dataloader args, optional 85 | CFG.TRAIN.DATA.BATCH_SIZE = 8 86 | CFG.TRAIN.DATA.PREFETCH = False 87 | CFG.TRAIN.DATA.SHUFFLE = True 88 | CFG.TRAIN.DATA.NUM_WORKERS = 2 89 | CFG.TRAIN.DATA.PIN_MEMORY = True 90 | 91 | # ================= validate ================= # 92 | CFG.VAL = EasyDict() 93 | CFG.VAL.INTERVAL = 1 94 | # validating data 95 | CFG.VAL.DATA = EasyDict() 96 | # read data 97 | CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME 98 | # dataloader args, optional 99 | CFG.VAL.DATA.BATCH_SIZE = 8 100 | CFG.VAL.DATA.PREFETCH = False 101 | CFG.VAL.DATA.SHUFFLE = False 102 | CFG.VAL.DATA.NUM_WORKERS = 2 103 | CFG.VAL.DATA.PIN_MEMORY = True 104 | 105 | # ================= test ================= # 106 | 
CFG.TEST = EasyDict() 107 | CFG.TEST.INTERVAL = 1 108 | # evluation 109 | # test data 110 | CFG.TEST.DATA = EasyDict() 111 | # read data 112 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 113 | # dataloader args, optional 114 | CFG.TEST.DATA.BATCH_SIZE = 8 115 | CFG.TEST.DATA.PREFETCH = False 116 | CFG.TEST.DATA.SHUFFLE = False 117 | CFG.TEST.DATA.NUM_WORKERS = 2 118 | CFG.TEST.DATA.PIN_MEMORY = True 119 | -------------------------------------------------------------------------------- /step/TSFormer_PEMS-BAY.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # TODO: remove it when basicts can be installed by pip 5 | sys.path.append(os.path.abspath(__file__ + "/../../..")) 6 | from easydict import EasyDict 7 | from basicts.losses import masked_mae 8 | 9 | from .step_arch import TSFormer 10 | from .step_runner import TSFormerRunner 11 | from .step_data import PretrainingDataset 12 | 13 | 14 | CFG = EasyDict() 15 | 16 | # ================= general ================= # 17 | CFG.DESCRIPTION = "TSFormer(PEMS-BAY) configuration" 18 | CFG.RUNNER = TSFormerRunner 19 | CFG.DATASET_CLS = PretrainingDataset 20 | CFG.DATASET_NAME = "PEMS-BAY" 21 | CFG.DATASET_TYPE = "Traffic speed" 22 | CFG.DATASET_INPUT_LEN = 288 * 7 23 | CFG.DATASET_OUTPUT_LEN = 12 24 | CFG.GPU_NUM = 2 25 | 26 | # ================= environment ================= # 27 | CFG.ENV = EasyDict() 28 | CFG.ENV.SEED = 0 29 | CFG.ENV.CUDNN = EasyDict() 30 | CFG.ENV.CUDNN.ENABLED = True 31 | 32 | # ================= model ================= # 33 | CFG.MODEL = EasyDict() 34 | CFG.MODEL.NAME = "TSFormer" 35 | CFG.MODEL.ARCH = TSFormer 36 | CFG.MODEL.PARAM = { 37 | "patch_size":12, 38 | "in_channel":1, 39 | "embed_dim":96, 40 | "num_heads":4, 41 | "mlp_ratio":4, 42 | "dropout":0.1, 43 | "num_token":288 * 7 / 12, 44 | "mask_ratio":0.75, 45 | "encoder_depth":4, 46 | "decoder_depth":1, 47 | "mode":"pre-train" 48 | } 49 | CFG.MODEL.FORWARD_FEATURES = 
[0] 50 | CFG.MODEL.TARGET_FEATURES = [0] 51 | 52 | # ================= optim ================= # 53 | CFG.TRAIN = EasyDict() 54 | CFG.TRAIN.LOSS = masked_mae 55 | CFG.TRAIN.OPTIM = EasyDict() 56 | CFG.TRAIN.OPTIM.TYPE = "Adam" 57 | CFG.TRAIN.OPTIM.PARAM= { 58 | "lr":0.001, 59 | "weight_decay":0, 60 | "eps":1.0e-8, 61 | "betas":(0.9, 0.95) 62 | } 63 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 64 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 65 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 66 | "milestones":[50], 67 | "gamma":0.5 68 | } 69 | 70 | # ================= train ================= # 71 | CFG.TRAIN.CLIP_GRAD_PARAM = { 72 | "max_norm": 5.0 73 | } 74 | CFG.TRAIN.NUM_EPOCHS = 100 75 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 76 | "checkpoints", 77 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 78 | ) 79 | # train data 80 | CFG.TRAIN.DATA = EasyDict() 81 | CFG.TRAIN.NULL_VAL = 0.0 82 | # read data 83 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 84 | # dataloader args, optional 85 | CFG.TRAIN.DATA.BATCH_SIZE = 16 86 | CFG.TRAIN.DATA.PREFETCH = False 87 | CFG.TRAIN.DATA.SHUFFLE = True 88 | CFG.TRAIN.DATA.NUM_WORKERS = 2 89 | CFG.TRAIN.DATA.PIN_MEMORY = True 90 | 91 | # ================= validate ================= # 92 | CFG.VAL = EasyDict() 93 | CFG.VAL.INTERVAL = 1 94 | # validating data 95 | CFG.VAL.DATA = EasyDict() 96 | # read data 97 | CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME 98 | # dataloader args, optional 99 | CFG.VAL.DATA.BATCH_SIZE = 16 100 | CFG.VAL.DATA.PREFETCH = False 101 | CFG.VAL.DATA.SHUFFLE = False 102 | CFG.VAL.DATA.NUM_WORKERS = 2 103 | CFG.VAL.DATA.PIN_MEMORY = True 104 | 105 | # ================= test ================= # 106 | CFG.TEST = EasyDict() 107 | CFG.TEST.INTERVAL = 1 108 | # evluation 109 | # test data 110 | CFG.TEST.DATA = EasyDict() 111 | # read data 112 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 113 | # dataloader args, optional 114 | CFG.TEST.DATA.BATCH_SIZE = 16 115 | CFG.TEST.DATA.PREFETCH = False 116 | 
CFG.TEST.DATA.SHUFFLE = False 117 | CFG.TEST.DATA.NUM_WORKERS = 2 118 | CFG.TEST.DATA.PIN_MEMORY = True 119 | -------------------------------------------------------------------------------- /step/TSFormer_PEMS03.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # TODO: remove it when basicts can be installed by pip 5 | sys.path.append(os.path.abspath(__file__ + "/../../..")) 6 | from easydict import EasyDict 7 | from basicts.losses import masked_mae 8 | 9 | from .step_arch import TSFormer 10 | from .step_runner import TSFormerRunner 11 | from .step_data import PretrainingDataset 12 | 13 | 14 | CFG = EasyDict() 15 | 16 | # ================= general ================= # 17 | CFG.DESCRIPTION = "TSFormer(PEMS03) configuration" 18 | CFG.RUNNER = TSFormerRunner 19 | CFG.DATASET_CLS = PretrainingDataset 20 | CFG.DATASET_NAME = "PEMS03" 21 | CFG.DATASET_TYPE = "Traffic flow" 22 | CFG.DATASET_INPUT_LEN = 288 * 7 23 | CFG.DATASET_OUTPUT_LEN = 12 24 | CFG.GPU_NUM = 2 25 | 26 | # ================= environment ================= # 27 | CFG.ENV = EasyDict() 28 | CFG.ENV.SEED = 0 29 | CFG.ENV.CUDNN = EasyDict() 30 | CFG.ENV.CUDNN.ENABLED = True 31 | 32 | # ================= model ================= # 33 | CFG.MODEL = EasyDict() 34 | CFG.MODEL.NAME = "TSFormer" 35 | CFG.MODEL.ARCH = TSFormer 36 | CFG.MODEL.PARAM = { 37 | "patch_size":12, 38 | "in_channel":1, 39 | "embed_dim":96, 40 | "num_heads":4, 41 | "mlp_ratio":4, 42 | "dropout":0.1, 43 | "num_token":288 * 7 / 12, 44 | "mask_ratio":0.75, 45 | "encoder_depth":4, 46 | "decoder_depth":1, 47 | "mode":"pre-train" 48 | } 49 | CFG.MODEL.FORWARD_FEATURES = [0] 50 | CFG.MODEL.TARGET_FEATURES = [0] 51 | 52 | # ================= optim ================= # 53 | CFG.TRAIN = EasyDict() 54 | CFG.TRAIN.LOSS = masked_mae 55 | CFG.TRAIN.OPTIM = EasyDict() 56 | CFG.TRAIN.OPTIM.TYPE = "Adam" 57 | CFG.TRAIN.OPTIM.PARAM= { 58 | "lr":0.001, 59 | "weight_decay":0, 60 | "eps":1.0e-8, 
61 | "betas":(0.9, 0.95) 62 | } 63 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 64 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 65 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 66 | "milestones":[50], 67 | "gamma":0.5 68 | } 69 | 70 | # ================= train ================= # 71 | CFG.TRAIN.CLIP_GRAD_PARAM = { 72 | "max_norm": 5.0 73 | } 74 | CFG.TRAIN.NUM_EPOCHS = 200 75 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 76 | "checkpoints", 77 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 78 | ) 79 | # train data 80 | CFG.TRAIN.DATA = EasyDict() 81 | CFG.TRAIN.NULL_VAL = 0.0 82 | # read data 83 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 84 | # dataloader args, optional 85 | CFG.TRAIN.DATA.BATCH_SIZE = 3 86 | CFG.TRAIN.DATA.PREFETCH = False 87 | CFG.TRAIN.DATA.SHUFFLE = True 88 | CFG.TRAIN.DATA.NUM_WORKERS = 2 89 | CFG.TRAIN.DATA.PIN_MEMORY = True 90 | 91 | # ================= validate ================= # 92 | CFG.VAL = EasyDict() 93 | CFG.VAL.INTERVAL = 1 94 | # validating data 95 | CFG.VAL.DATA = EasyDict() 96 | # read data 97 | CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME 98 | # dataloader args, optional 99 | CFG.VAL.DATA.BATCH_SIZE = 3 100 | CFG.VAL.DATA.PREFETCH = False 101 | CFG.VAL.DATA.SHUFFLE = False 102 | CFG.VAL.DATA.NUM_WORKERS = 2 103 | CFG.VAL.DATA.PIN_MEMORY = True 104 | 105 | # ================= test ================= # 106 | CFG.TEST = EasyDict() 107 | CFG.TEST.INTERVAL = 100 108 | # evluation 109 | # test data 110 | CFG.TEST.DATA = EasyDict() 111 | # read data 112 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 113 | # dataloader args, optional 114 | CFG.TEST.DATA.BATCH_SIZE = 3 115 | CFG.TEST.DATA.PREFETCH = False 116 | CFG.TEST.DATA.SHUFFLE = False 117 | CFG.TEST.DATA.NUM_WORKERS = 2 118 | CFG.TEST.DATA.PIN_MEMORY = True 119 | -------------------------------------------------------------------------------- /step/TSFormer_PEMS04.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
sys 3 | 4 | # TODO: remove it when basicts can be installed by pip 5 | sys.path.append(os.path.abspath(__file__ + "/../../..")) 6 | from easydict import EasyDict 7 | from basicts.losses import masked_mae 8 | 9 | from .step_arch import TSFormer 10 | from .step_runner import TSFormerRunner 11 | from .step_data import PretrainingDataset 12 | 13 | 14 | CFG = EasyDict() 15 | 16 | # ================= general ================= # 17 | CFG.DESCRIPTION = "TSFormer(PEMS04) configuration" 18 | CFG.RUNNER = TSFormerRunner 19 | CFG.DATASET_CLS = PretrainingDataset 20 | CFG.DATASET_NAME = "PEMS04" 21 | CFG.DATASET_TYPE = "Traffic flow" 22 | CFG.DATASET_INPUT_LEN = 288 * 7 * 2 23 | CFG.DATASET_OUTPUT_LEN = 12 24 | CFG.GPU_NUM = 2 25 | 26 | # ================= environment ================= # 27 | CFG.ENV = EasyDict() 28 | CFG.ENV.SEED = 0 29 | CFG.ENV.CUDNN = EasyDict() 30 | CFG.ENV.CUDNN.ENABLED = True 31 | 32 | # ================= model ================= # 33 | CFG.MODEL = EasyDict() 34 | CFG.MODEL.NAME = "TSFormer" 35 | CFG.MODEL.ARCH = TSFormer 36 | CFG.MODEL.PARAM = { 37 | "patch_size":12, 38 | "in_channel":1, 39 | "embed_dim":96, 40 | "num_heads":4, 41 | "mlp_ratio":4, 42 | "dropout":0.1, 43 | "num_token":288 * 7 * 2 / 12, 44 | "mask_ratio":0.75, 45 | "encoder_depth":4, 46 | "decoder_depth":1, 47 | "mode":"pre-train" 48 | } 49 | CFG.MODEL.FORWARD_FEATURES = [0] 50 | CFG.MODEL.TARGET_FEATURES = [0] 51 | 52 | # ================= optim ================= # 53 | CFG.TRAIN = EasyDict() 54 | CFG.TRAIN.LOSS = masked_mae 55 | CFG.TRAIN.OPTIM = EasyDict() 56 | CFG.TRAIN.OPTIM.TYPE = "Adam" 57 | CFG.TRAIN.OPTIM.PARAM= { 58 | "lr":0.001, 59 | "weight_decay":0, 60 | "eps":1.0e-8, 61 | "betas":(0.9, 0.95) 62 | } 63 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 64 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 65 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 66 | "milestones":[50], 67 | "gamma":0.5 68 | } 69 | 70 | # ================= train ================= # 71 | CFG.TRAIN.CLIP_GRAD_PARAM = { 72 | "max_norm": 
5.0 73 | } 74 | CFG.TRAIN.NUM_EPOCHS = 200 75 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 76 | "checkpoints", 77 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 78 | ) 79 | # train data 80 | CFG.TRAIN.DATA = EasyDict() 81 | CFG.TRAIN.NULL_VAL = 0.0 82 | # read data 83 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 84 | # dataloader args, optional 85 | CFG.TRAIN.DATA.BATCH_SIZE = 6 86 | CFG.TRAIN.DATA.PREFETCH = False 87 | CFG.TRAIN.DATA.SHUFFLE = True 88 | CFG.TRAIN.DATA.NUM_WORKERS = 2 89 | CFG.TRAIN.DATA.PIN_MEMORY = True 90 | 91 | # ================= validate ================= # 92 | CFG.VAL = EasyDict() 93 | CFG.VAL.INTERVAL = 1 94 | # validating data 95 | CFG.VAL.DATA = EasyDict() 96 | # read data 97 | CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME 98 | # dataloader args, optional 99 | CFG.VAL.DATA.BATCH_SIZE = 8 100 | CFG.VAL.DATA.PREFETCH = False 101 | CFG.VAL.DATA.SHUFFLE = False 102 | CFG.VAL.DATA.NUM_WORKERS = 2 103 | CFG.VAL.DATA.PIN_MEMORY = True 104 | 105 | # ================= test ================= # 106 | CFG.TEST = EasyDict() 107 | CFG.TEST.INTERVAL = 1 108 | # evluation 109 | # test data 110 | CFG.TEST.DATA = EasyDict() 111 | # read data 112 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 113 | # dataloader args, optional 114 | CFG.TEST.DATA.BATCH_SIZE = 8 115 | CFG.TEST.DATA.PREFETCH = False 116 | CFG.TEST.DATA.SHUFFLE = False 117 | CFG.TEST.DATA.NUM_WORKERS = 2 118 | CFG.TEST.DATA.PIN_MEMORY = True 119 | -------------------------------------------------------------------------------- /step/TSFormer_PEMS07.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # TODO: remove it when basicts can be installed by pip 5 | sys.path.append(os.path.abspath(__file__ + "/../../..")) 6 | from easydict import EasyDict 7 | from basicts.losses import masked_mae 8 | 9 | from .step_arch import TSFormer 10 | from .step_runner import TSFormerRunner 11 | from .step_data import 
PretrainingDataset 12 | 13 | 14 | CFG = EasyDict() 15 | 16 | # ================= general ================= # 17 | CFG.DESCRIPTION = "TSFormer(PEMS07) configuration" 18 | CFG.RUNNER = TSFormerRunner 19 | CFG.DATASET_CLS = PretrainingDataset 20 | CFG.DATASET_NAME = "PEMS07" 21 | CFG.DATASET_TYPE = "Traffic flow" 22 | CFG.DATASET_INPUT_LEN = 288 * 7 23 | CFG.DATASET_OUTPUT_LEN = 12 24 | CFG.GPU_NUM = 2 25 | 26 | # ================= environment ================= # 27 | CFG.ENV = EasyDict() 28 | CFG.ENV.SEED = 0 29 | CFG.ENV.CUDNN = EasyDict() 30 | CFG.ENV.CUDNN.ENABLED = True 31 | 32 | # ================= model ================= # 33 | CFG.MODEL = EasyDict() 34 | CFG.MODEL.NAME = "TSFormer" 35 | CFG.MODEL.ARCH = TSFormer 36 | CFG.MODEL.PARAM = { 37 | "patch_size":12, 38 | "in_channel":1, 39 | "embed_dim":96, 40 | "num_heads":4, 41 | "mlp_ratio":4, 42 | "dropout":0.1, 43 | "num_token":288 * 7 / 12, 44 | "mask_ratio":0.75, 45 | "encoder_depth":4, 46 | "decoder_depth":1, 47 | "mode":"pre-train" 48 | } 49 | CFG.MODEL.FORWARD_FEATURES = [0] 50 | CFG.MODEL.TARGET_FEATURES = [0] 51 | 52 | # ================= optim ================= # 53 | CFG.TRAIN = EasyDict() 54 | CFG.TRAIN.LOSS = masked_mae 55 | CFG.TRAIN.OPTIM = EasyDict() 56 | CFG.TRAIN.OPTIM.TYPE = "Adam" 57 | CFG.TRAIN.OPTIM.PARAM= { 58 | "lr":0.001, 59 | "weight_decay":0, 60 | "eps":1.0e-8, 61 | "betas":(0.9, 0.95) 62 | } 63 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 64 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 65 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 66 | "milestones":[50], 67 | "gamma":0.5 68 | } 69 | 70 | # ================= train ================= # 71 | CFG.TRAIN.CLIP_GRAD_PARAM = { 72 | "max_norm": 5.0 73 | } 74 | CFG.TRAIN.NUM_EPOCHS = 200 75 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 76 | "checkpoints", 77 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 78 | ) 79 | # train data 80 | CFG.TRAIN.DATA = EasyDict() 81 | CFG.TRAIN.NULL_VAL = 0.0 82 | # read data 83 | CFG.TRAIN.DATA.DIR = "datasets/" + 
CFG.DATASET_NAME 84 | # dataloader args, optional 85 | CFG.TRAIN.DATA.BATCH_SIZE = 3 86 | CFG.TRAIN.DATA.PREFETCH = False 87 | CFG.TRAIN.DATA.SHUFFLE = True 88 | CFG.TRAIN.DATA.NUM_WORKERS = 2 89 | CFG.TRAIN.DATA.PIN_MEMORY = True 90 | 91 | # ================= validate ================= # 92 | CFG.VAL = EasyDict() 93 | CFG.VAL.INTERVAL = 1 94 | # validating data 95 | CFG.VAL.DATA = EasyDict() 96 | # read data 97 | CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME 98 | # dataloader args, optional 99 | CFG.VAL.DATA.BATCH_SIZE = 3 100 | CFG.VAL.DATA.PREFETCH = False 101 | CFG.VAL.DATA.SHUFFLE = False 102 | CFG.VAL.DATA.NUM_WORKERS = 2 103 | CFG.VAL.DATA.PIN_MEMORY = True 104 | 105 | # ================= test ================= # 106 | CFG.TEST = EasyDict() 107 | CFG.TEST.INTERVAL = 100 108 | # evluation 109 | # test data 110 | CFG.TEST.DATA = EasyDict() 111 | # read data 112 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 113 | # dataloader args, optional 114 | CFG.TEST.DATA.BATCH_SIZE = 3 115 | CFG.TEST.DATA.PREFETCH = False 116 | CFG.TEST.DATA.SHUFFLE = False 117 | CFG.TEST.DATA.NUM_WORKERS = 2 118 | CFG.TEST.DATA.PIN_MEMORY = True 119 | -------------------------------------------------------------------------------- /step/TSFormer_PEMS08.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # TODO: remove it when basicts can be installed by pip 5 | sys.path.append(os.path.abspath(__file__ + "/../../..")) 6 | from easydict import EasyDict 7 | from basicts.losses import masked_mae 8 | 9 | from .step_arch import TSFormer 10 | from .step_runner import TSFormerRunner 11 | from .step_data import PretrainingDataset 12 | 13 | 14 | CFG = EasyDict() 15 | 16 | # ================= general ================= # 17 | CFG.DESCRIPTION = "TSFormer(PEMS08) configuration" 18 | CFG.RUNNER = TSFormerRunner 19 | CFG.DATASET_CLS = PretrainingDataset 20 | CFG.DATASET_NAME = "PEMS08" 21 | CFG.DATASET_TYPE = "Traffic flow" 
22 | CFG.DATASET_INPUT_LEN = 288 * 7 * 2 23 | CFG.DATASET_OUTPUT_LEN = 12 24 | CFG.GPU_NUM = 2 25 | 26 | # ================= environment ================= # 27 | CFG.ENV = EasyDict() 28 | CFG.ENV.SEED = 0 29 | CFG.ENV.CUDNN = EasyDict() 30 | CFG.ENV.CUDNN.ENABLED = True 31 | 32 | # ================= model ================= # 33 | CFG.MODEL = EasyDict() 34 | CFG.MODEL.NAME = "TSFormer" 35 | CFG.MODEL.ARCH = TSFormer 36 | CFG.MODEL.PARAM = { 37 | "patch_size":12, 38 | "in_channel":1, 39 | "embed_dim":96, 40 | "num_heads":4, 41 | "mlp_ratio":4, 42 | "dropout":0.1, 43 | "num_token":288 * 7 * 2 / 12, 44 | "mask_ratio":0.75, 45 | "encoder_depth":4, 46 | "decoder_depth":1, 47 | "mode":"pre-train" 48 | } 49 | CFG.MODEL.FORWARD_FEATURES = [0] 50 | CFG.MODEL.TARGET_FEATURES = [0] 51 | 52 | # ================= optim ================= # 53 | CFG.TRAIN = EasyDict() 54 | CFG.TRAIN.LOSS = masked_mae 55 | CFG.TRAIN.OPTIM = EasyDict() 56 | CFG.TRAIN.OPTIM.TYPE = "Adam" 57 | CFG.TRAIN.OPTIM.PARAM= { 58 | "lr":0.001, 59 | "weight_decay":0, 60 | "eps":1.0e-8, 61 | "betas":(0.9, 0.95) 62 | } 63 | CFG.TRAIN.LR_SCHEDULER = EasyDict() 64 | CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" 65 | CFG.TRAIN.LR_SCHEDULER.PARAM= { 66 | "milestones":[50], 67 | "gamma":0.5 68 | } 69 | 70 | # ================= train ================= # 71 | CFG.TRAIN.CLIP_GRAD_PARAM = { 72 | "max_norm": 5.0 73 | } 74 | CFG.TRAIN.NUM_EPOCHS = 200 75 | CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( 76 | "checkpoints", 77 | "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) 78 | ) 79 | # train data 80 | CFG.TRAIN.DATA = EasyDict() 81 | CFG.TRAIN.NULL_VAL = 0.0 82 | # read data 83 | CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME 84 | # dataloader args, optional 85 | CFG.TRAIN.DATA.BATCH_SIZE = 6 86 | CFG.TRAIN.DATA.PREFETCH = False 87 | CFG.TRAIN.DATA.SHUFFLE = True 88 | CFG.TRAIN.DATA.NUM_WORKERS = 2 89 | CFG.TRAIN.DATA.PIN_MEMORY = True 90 | 91 | # ================= validate ================= # 92 | CFG.VAL = EasyDict() 
93 | CFG.VAL.INTERVAL = 1 94 | # validating data 95 | CFG.VAL.DATA = EasyDict() 96 | # read data 97 | CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME 98 | # dataloader args, optional 99 | CFG.VAL.DATA.BATCH_SIZE = 8 100 | CFG.VAL.DATA.PREFETCH = False 101 | CFG.VAL.DATA.SHUFFLE = False 102 | CFG.VAL.DATA.NUM_WORKERS = 2 103 | CFG.VAL.DATA.PIN_MEMORY = True 104 | 105 | # ================= test ================= # 106 | CFG.TEST = EasyDict() 107 | CFG.TEST.INTERVAL = 1 108 | # evaluation 109 | # test data 110 | CFG.TEST.DATA = EasyDict() 111 | # read data 112 | CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME 113 | # dataloader args, optional 114 | CFG.TEST.DATA.BATCH_SIZE = 8 115 | CFG.TEST.DATA.PREFETCH = False 116 | CFG.TEST.DATA.SHUFFLE = False 117 | CFG.TEST.DATA.NUM_WORKERS = 2 118 | CFG.TEST.DATA.PIN_MEMORY = True 119 | -------------------------------------------------------------------------------- /step/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from argparse import ArgumentParser 4 | 5 | # TODO: remove it when basicts can be installed by pip 6 | sys.path.append(os.path.abspath(__file__ + "/../..")) 7 | import torch 8 | from basicts import launch_training 9 | 10 | torch.set_num_threads(2) # avoid high average CPU usage 11 | 12 | 13 | def parse_args(): 14 | parser = ArgumentParser(description="Run time series forecasting model in BasicTS framework!") 15 | # parser.add_argument("-c", "--cfg", default="step/TSFormer_METR-LA.py", help="training config") 16 | # parser.add_argument("-c", "--cfg", default="step/STEP_METR-LA.py", help="training config") 17 | 18 | # parser.add_argument("-c", "--cfg", default="step/TSFormer_PEMS04.py", help="training config") 19 | # parser.add_argument("-c", "--cfg", default="step/STEP_PEMS04.py", help="training config") 20 | 21 | # parser.add_argument("-c", "--cfg", default="step/TSFormer_PEMS-BAY.py", help="training config") 22 | # 
def batch_cosine_similarity(x, y):
    """Return the pairwise cosine similarity between the rows of ``x`` and ``y``.

    Both inputs are treated as 3-D tensors: the L2 norm is taken over dim 2
    and the pairwise products are formed by transposing dims 1 and 2.

    Args:
        x (torch.Tensor): first batch of row vectors, features on dim 2.
        y (torch.Tensor): second batch of row vectors, features on dim 2.

    Returns:
        torch.Tensor: ``(x @ y^T) / (|x| |y|^T)`` computed per batch entry.
    """
    eps = 1e-7  # keeps the denominator away from zero for all-zero rows

    # denominator: outer product of the (shifted) L2 norms of every row pair
    x_norm = x.norm(p=2, dim=2) + eps
    y_norm = y.norm(p=2, dim=2) + eps
    norm_outer = torch.matmul(x_norm.unsqueeze(2), y_norm.unsqueeze(2).transpose(1, 2))

    # numerator: dot product of every row pair
    dot_products = torch.matmul(x, y.transpose(1, 2))

    # cosine-similarity affinity matrix
    return dot_products / norm_outer
class STEP(nn.Module):
    """Pre-training Enhanced Spatial-temporal Graph Neural Network for Multivariate Time Series Forecasting.

    Combines a frozen, pre-trained TSFormer, a discrete graph learning module
    and a GraphWaveNet backend.
    """

    def __init__(self, dataset_name, pre_trained_tsformer_path, tsformer_args, backend_args, dgl_args):
        """Build the sub-modules and load the frozen TSFormer weights.

        Args:
            dataset_name: name of the dataset (stored, not used further in this class).
            pre_trained_tsformer_path: path of the pre-trained TSFormer checkpoint.
            tsformer_args (dict): keyword arguments for ``TSFormer``.
            backend_args (dict): keyword arguments for ``GraphWaveNet``.
            dgl_args (dict): keyword arguments for ``DiscreteGraphLearning``.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.pre_trained_tsformer_path = pre_trained_tsformer_path

        # initialize the tsformer and backend models
        self.tsformer = TSFormer(**tsformer_args)
        self.backend = GraphWaveNet(**backend_args)

        # load pre-trained tsformer
        self.load_pre_trained_model()

        # discrete graph learning
        self.discrete_graph_learning = DiscreteGraphLearning(**dgl_args)

    def load_pre_trained_model(self):
        """Load the pre-trained TSFormer weights and freeze them."""

        # Deserialize onto the CPU: without map_location a checkpoint saved on a
        # GPU is restored onto that same device, which fails on CPU-only hosts
        # and piles every process onto one GPU in multi-GPU runs.
        # load_state_dict() then copies the values into the existing parameters,
        # wherever they live.
        checkpoint_dict = torch.load(self.pre_trained_tsformer_path, map_location="cpu")
        self.tsformer.load_state_dict(checkpoint_dict["model_state_dict"])
        # freeze parameters
        for param in self.tsformer.parameters():
            param.requires_grad = False

    def forward(self, history_data: torch.Tensor, long_history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, **kwargs) -> tuple:
        """Feed forward of STEP.

        Args:
            history_data (torch.Tensor): short-term historical data; unpacked here as
                a 4-D tensor whose dim 0 is the batch and dim 2 the nodes.
            long_history_data (torch.Tensor): long-term historical data, passed
                unchanged to the discrete graph learning module.
            future_data (torch.Tensor): future data (unused here).
            batch_seen (int): number of batches that have been seen (unused here).
            epoch (int): current epoch, or None outside of training.

        Returns:
            tuple: four elements, in this order:
                - the backend prediction with dims 1 and 2 swapped and a trailing
                  singleton axis appended (presumably [B, L, N, 1] -- confirm
                  against the backend's output layout);
                - the first softmax component of the unnormalised Bernoulli
                  parameters, reshaped to [B, N, N];
                - the prior graph returned by the graph learning module, used to
                  guide the training of the dependency graph;
                - the graph-structure-loss coefficient (a number).
        """

        short_term_history = history_data
        long_term_history = long_history_data

        batch_size, _, num_nodes, _ = short_term_history.shape

        # discrete graph learning & feed forward of TSFormer
        bernoulli_unnorm, hidden_states, adj_knn, sampled_adj = self.discrete_graph_learning(long_term_history, self.tsformer)

        # enhancing downstream STGNNs: keep only the last position along dim 2
        hidden_states = hidden_states[:, :, -1, :]
        y_hat = self.backend(short_term_history, hidden_states=hidden_states, sampled_adj=sampled_adj).transpose(1, 2)

        # graph structure loss coefficient: decays stepwise every 6 epochs,
        # and is switched off when no epoch is given
        if epoch is not None:
            gsl_coefficient = 1 / (int(epoch/6)+1)
        else:
            gsl_coefficient = 0
        return y_hat.unsqueeze(-1), bernoulli_unnorm.softmax(-1)[..., 0].clone().reshape(batch_size, num_nodes, num_nodes), adj_knn, gsl_coefficient
class PatchEmbedding(nn.Module):
    """Cut a long time series into non-overlapping patches and embed each patch."""

    def __init__(self, patch_size, in_channel, embed_dim, norm_layer):
        """Create the patch projection.

        Args:
            patch_size (int): number of time steps per patch (the L).
            in_channel (int): number of input channels.
            embed_dim (int): embedding size of every patch.
            norm_layer (nn.Module or None): layer applied to the embedded
                patches; ``None`` selects ``nn.Identity``.
        """
        super().__init__()
        self.len_patch = patch_size
        self.input_channel = in_channel
        self.output_channel = embed_dim
        # kernel == stride along the time axis, so patches never overlap
        patch_window = (self.len_patch, 1)
        self.input_embedding = nn.Conv2d(
            in_channel,
            embed_dim,
            kernel_size=patch_window,
            stride=patch_window)
        if norm_layer is None:
            norm_layer = nn.Identity()
        self.norm_layer = norm_layer

    def forward(self, long_term_history):
        """Embed the patches of a long multivariate time series.

        Args:
            long_term_history (torch.Tensor): series with shape [B, N, C, P * L],
                where P is the number of patches and L the patch length.

        Returns:
            torch.Tensor: patch embeddings with shape [B, N, d, P].
        """
        batch_size, num_nodes, num_feat, len_time_series = long_term_history.shape
        # fold the nodes into the batch axis and add a width-1 axis: B*N, C, P*L, 1
        flattened = long_term_history.unsqueeze(-1).reshape(
            batch_size * num_nodes, num_feat, len_time_series, 1)
        # B*N, d, P, 1
        embedded = self.norm_layer(self.input_embedding(flattened))
        # B, N, d, P
        patches = embedded.squeeze(-1).view(batch_size, num_nodes, self.output_channel, -1)
        assert patches.shape[-1] == len_time_series / self.len_patch
        return patches
class TransformerLayers(nn.Module):
    """Stack of standard Transformer encoder layers applied to every node's sequence."""

    def __init__(self, hidden_dim, nlayers, mlp_ratio, num_heads=4, dropout=0.1):
        """Build the encoder stack.

        Args:
            hidden_dim (int): model (embedding) dimension.
            nlayers (int): number of encoder layers.
            mlp_ratio (int): feed-forward width as a multiple of ``hidden_dim``.
            num_heads (int): number of attention heads.
            dropout (float): dropout probability of every encoder layer.
        """
        super().__init__()
        self.d_model = hidden_dim
        feedforward_dim = hidden_dim * mlp_ratio
        layer = TransformerEncoderLayer(hidden_dim, num_heads, feedforward_dim, dropout)
        self.transformer_encoder = TransformerEncoder(layer, nlayers)

    def forward(self, src):
        """Encode the input.

        Args:
            src (torch.Tensor): input with shape [B, N, L, D].

        Returns:
            torch.Tensor: encoded sequence with the same shape [B, N, L, D].
        """
        batch_size, num_nodes, seq_len, dim = src.shape
        # scale, then fold the nodes into the batch axis: B*N, L, D
        scaled = (src * math.sqrt(self.d_model)).view(batch_size * num_nodes, seq_len, dim)
        # the encoder is sequence-first here, hence the transposes around it
        encoded = self.transformer_encoder(scaled.transpose(0, 1), mask=None)
        return encoded.transpose(0, 1).view(batch_size, num_nodes, seq_len, dim)
class ForecastingDataset(Dataset):
    """Time series forecasting dataset that also serves a long-term history window."""

    def __init__(self, data_file_path: str, index_file_path: str, mode: str, seq_len:int) -> None:
        """Load the data and the sample index for the forecasting stage.

        Args:
            data_file_path (str): pickle holding a dict with a "processed_data" array.
            index_file_path (str): pickle holding the sample indices per mode.
            mode (str): train, valid, or test.
            seq_len (int): the length of long term historical data.
        """

        super().__init__()
        assert mode in ["train", "valid", "test"], "error mode"
        self._check_if_file_exists(data_file_path, index_file_path)
        # raw (normalized) data as a float tensor
        raw = load_pkl(data_file_path)
        self.data = torch.from_numpy(raw["processed_data"]).float()
        # sample index of the requested split
        self.index = load_pkl(index_file_path)[mode]
        # length of long term historical data
        self.seq_len = seq_len
        # all-zero placeholder returned when not enough history is available
        self.mask = torch.zeros(self.seq_len, self.data.shape[1], self.data.shape[2])

    def _check_if_file_exists(self, data_file_path: str, index_file_path: str):
        """Make sure both input files exist.

        Args:
            data_file_path (str): data file path
            index_file_path (str): index file path

        Raises:
            FileNotFoundError: no data file
            FileNotFoundError: no index file
        """

        data_found = os.path.isfile(data_file_path)
        if not data_found:
            raise FileNotFoundError("BasicTS can not find data file {0}".format(data_file_path))
        index_found = os.path.isfile(index_file_path)
        if not index_found:
            raise FileNotFoundError("BasicTS can not find index file {0}".format(index_file_path))

    def __getitem__(self, index: int) -> tuple:
        """Get a sample.

        Args:
            index (int): the iteration index (not the self.index)

        Returns:
            tuple: (future_data, history_data, long_history_data). The long
                history is the ``seq_len`` steps that end where the short
                history ends, or the all-zero mask when the series does not
                reach back that far.
        """

        bounds = list(self.index[index])
        start, split, end = bounds[0], bounds[1], bounds[2]

        history_data = self.data[start:split]
        future_data = self.data[split:end]

        long_start = split - self.seq_len
        long_history_data = self.mask if long_start < 0 else self.data[long_start:split]

        return future_data, history_data, long_history_data

    def __len__(self):
        """Dataset length

        Returns:
            int: dataset length
        """

        return len(self.index)
def step_loss(prediction, real_value, theta, priori_adj, gsl_coefficient, null_val=np.nan):
    """Combine the forecasting loss with the graph structure learning loss.

    Args:
        prediction: model prediction, forwarded to ``masked_mae`` as ``preds``.
        real_value: ground truth, forwarded to ``masked_mae`` as ``labels``.
        theta (torch.Tensor): learned edge probabilities, a 3-D tensor that is
            flattened to [B, N*N].
        priori_adj (torch.Tensor): prior graph used as the BCE target, flattened
            the same way.
        gsl_coefficient: weight of the graph structure loss.
        null_val: value treated as missing by ``masked_mae``.

    Returns:
        The prediction loss plus the weighted graph structure loss.
    """
    # graph structure learning loss: BCE between learned and prior edges
    batch_size, _, num_nodes = theta.shape
    flat_shape = (batch_size, num_nodes * num_nodes)
    graph_criterion = nn.BCELoss()
    loss_graph = graph_criterion(theta.view(*flat_shape), priori_adj.view(*flat_shape))
    # prediction loss
    loss_pred = masked_mae(preds=prediction, labels=real_value, null_val=null_val)
    # final loss
    return loss_pred + loss_graph * gsl_coefficient
def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
    """Select the target features along the last axis of a 4-D tensor.

    Args:
        data (torch.Tensor): tensor indexed as [B, L, N, C].

    Returns:
        torch.Tensor: ``data`` restricted to ``self.target_features`` on the
            last axis, or ``data`` unchanged when no target features are
            configured.
    """

    # ``target_features`` defaults to None when the config has no
    # TARGET_FEATURES entry. Indexing with None would insert a new axis
    # instead of selecting features, so guard it the same way
    # select_input_features guards ``forward_features``.
    if self.target_features is not None:
        data = data[:, :, :, self.target_features]
    return data
def __init__(self, cfg: dict):
    """Initialise the base runner and read the feature selections.

    Args:
        cfg (dict): runner config; ``cfg["MODEL"]`` may carry the optional
            FORWARD_FEATURES and TARGET_FEATURES entries (None when absent).
    """
    super().__init__(cfg)
    model_cfg = cfg["MODEL"]
    self.forward_features = model_cfg.get("FORWARD_FEATURES", None)
    self.target_features = model_cfg.get("TARGET_FEATURES", None)
def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
    """Select the target features along the last axis of a 4-D tensor.

    Args:
        data (torch.Tensor): tensor indexed as [B, L, N, C].

    Returns:
        torch.Tensor: ``data`` restricted to ``self.target_features`` on the
            last axis, or ``data`` unchanged when no target features are
            configured.
    """

    # ``target_features`` defaults to None when the config has no
    # TARGET_FEATURES entry. Indexing with None would insert a new axis
    # instead of selecting features, so guard it the same way
    # select_input_features guards ``forward_features``.
    if self.target_features is not None:
        data = data[:, :, :, self.target_features]
    return data
@torch.no_grad()
@master_only
def test(self):
    """Evaluate the model on the test loader.

    For every batch the two tensors returned by ``forward`` are passed through
    the scaler function registered under ``self.scaler["func"]``, and every
    metric in ``self.metrics`` is recorded under the name ``test_<metric>``.
    """

    for batch in self.test_data_loader:
        outputs = self.forward(data=batch, epoch=None, iter_num=None, train=False)
        # re-scale data
        rescale = SCALER_REGISTRY.get(self.scaler["func"])
        prediction_rescaled = rescale(outputs[0], **self.scaler["args"])
        real_value_rescaled = rescale(outputs[1], **self.scaler["args"])
        # metrics
        for metric_name, metric_func in self.metrics.items():
            metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val)
            self.update_epoch_meter("test_"+metric_name, metric_item.item())
def main(cfg: dict, runner: Runner, ckpt: str = None):
    """Run the test process of a runner from a saved checkpoint.

    Args:
        cfg (dict): config handed to the runner's test process.
        runner (Runner): runner instance to evaluate.
        ckpt (str, optional): checkpoint path forwarded to ``load_model``.
    """
    # the logger must exist before the checkpoint is loaded and the test runs
    runner.init_logger(logger_name='easytorch-inference', log_file_name='validate_result')

    # restore the weights, then evaluate
    runner.load_model(ckpt_path=ckpt)
    runner.test_process(cfg)
https://raw.githubusercontent.com/GestaltCogTeam/STEP/566e2738da2d83f055718d8edb609ad8dc325204/tsformer_ckpt/TSFormer_PEMS04.pt -------------------------------------------------------------------------------- /tsformer_ckpt/TSFormer_PEMS07.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GestaltCogTeam/STEP/566e2738da2d83f055718d8edb609ad8dc325204/tsformer_ckpt/TSFormer_PEMS07.pt -------------------------------------------------------------------------------- /tsformer_ckpt/TSFormer_PEMS08.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GestaltCogTeam/STEP/566e2738da2d83f055718d8edb609ad8dc325204/tsformer_ckpt/TSFormer_PEMS08.pt --------------------------------------------------------------------------------