├── LICENSE.txt ├── README.md ├── SAVE └── pretrain │ ├── GWN │ └── GWN_P8437.pth │ ├── MTGNN │ └── MTGNN_P8437.pth │ ├── PDFormer │ └── PDFormer_P8437.pth │ └── STGCN │ └── STGCN_P8437.pth ├── conf ├── AGCRN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS07M.conf │ └── chengdu_didi.conf ├── ASTGCN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS07M.conf │ └── chengdu_didi.conf ├── FlashST │ ├── Params_pretrain.py │ └── config.conf ├── GWN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS03.conf │ ├── PEMS04.conf │ ├── PEMS07.conf │ ├── PEMS07M.conf │ ├── PEMS08.conf │ └── chengdu_didi.conf ├── MSDR │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS07M.conf │ └── chengdu_didi.conf ├── MTGNN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS03.conf │ ├── PEMS04.conf │ ├── PEMS07.conf │ ├── PEMS07M.conf │ ├── PEMS08.conf │ └── chengdu_didi.conf ├── PDFormer │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS03.conf │ ├── PEMS04.conf │ ├── PEMS07.conf │ ├── PEMS07M.conf │ ├── PEMS08.conf │ └── chengdu_didi.conf ├── ST-WA │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS07M.conf │ └── chengdu_didi.conf ├── STFGNN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS07M.conf │ └── chengdu_didi.conf ├── STGCN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS03.conf │ ├── PEMS04.conf │ ├── PEMS07.conf │ ├── PEMS07M.conf │ ├── PEMS08.conf │ └── chengdu_didi.conf ├── STSGCN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS07M.conf │ └── chengdu_didi.conf └── TGCN │ ├── CA_District5.conf │ ├── NYC_BIKE.conf │ ├── PEMS07M.conf │ └── chengdu_didi.conf ├── lib ├── TrainInits.py ├── data_process.py ├── logger.py ├── metrics.py └── predifineGraph.py ├── model ├── AGCRN │ ├── AGCN.py │ ├── AGCRN.py │ ├── AGCRNCell.py │ └── args.py ├── ASTGCN │ ├── ASTGCN.py │ └── args.py ├── DMSTGCN │ └── DMSTGCN.py ├── FlashST.py ├── GWN │ ├── GWN.py │ └── args.py ├── MSDR │ ├── args.py │ ├── gmsdr_cell.py │ └── gmsdr_model.py ├── MTGNN │ ├── MTGNN.py │ └── args.py ├── PDFormer │ ├── PDFformer.py │ └── args.py ├── PromptNet.py ├── Run.py ├── STFGNN │ ├── STFGNN.py │ └── args.py ├── STGCN │ ├── args.py │ └── stgcn.py ├── STGODE │ ├── STGODE.py │ ├── args.py │ └── odegcn.py ├── STSGCN │ ├── STSGCN.py │ └── args.py ├── ST_WA │ ├── ST_WA.py │ ├── args.py │ └── attention.py ├── TGCN │ ├── TGCN.py │ └── args.py └── Trainer.py └── requirements.txt /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction 2 | 3 | 4 | 5 | A pytorch implementation for the paper: [FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction](https://arxiv.org/abs/2405.17898)
6 | 7 | [Zhonghang Li](https://scholar.google.com/citations?user=__9uvQkAAAAJ), [Lianghao Xia](https://akaxlh.github.io/), [Yong Xu](https://scholar.google.com/citations?user=1hx5iwEAAAAJ), [Chao Huang](https://sites.google.com/view/chaoh)* (*Correspondence)
8 | 9 | **[Data Intelligence Lab](https://sites.google.com/view/chaoh/home)@[University of Hong Kong](https://www.hku.hk/)**, [South China University of Technology](https://www.scut.edu.cn/en/), PAZHOU LAB 10 | 11 | ----------- 12 | 13 | ## Introduction 14 | 15 |

16 | In this work, we introduce a simple and universal spatio-temporal prompt-tuning framework, which addresses the significant challenge posed by distribution shift in this field. 17 | To achieve this objective, we present FlashST, a framework that adapts pretrained models to the specific characteristics of diverse downstream datasets, thereby improving generalization across various prediction scenarios. 18 | We begin by utilizing a lightweight spatio-temporal prompt network for in-context learning, capturing spatio-temporal invariant knowledge and facilitating effective adaptation to diverse scenarios. Additionally, we incorporate a distribution mapping mechanism to align the data distributions of pre-training and downstream data, facilitating effective knowledge transfer in spatio-temporal forecasting. Empirical evaluations demonstrate the effectiveness of our FlashST across different spatio-temporal prediction tasks. 19 | 20 |

21 | 22 | ![The detailed framework of the proposed FlashST.](https://github.com/LZH-YS1998/GPT-ST_img/blob/main/FlashST.png) 23 | 24 | ----------- 25 | 26 | 27 | 28 | 29 | ## Getting Started 30 | 31 | 32 | 33 | ### Table of Contents: 34 | * 1. Code Structure 35 | * 2. Environment 36 | * 3. Run the codes 37 | 38 | **** 39 | 40 | 41 | 42 | 43 | ### 1. Code Structure [Back to Top] 44 | 45 | 46 | * **conf**: This folder includes parameter settings for FlashST (`config.conf`) as well as all other baseline models. 47 | * **data**: The documentation encompasses all the datasets utilized in our work, alongside prefabricated files and the corresponding file generation codes necessary for certain baselines. 48 | * **lib**: Including a series of initialization methods for data processing, as follows: 49 | * `data_process.py`: Load, split, generate data, normalization method, slicing, etc. 50 | * `logger.py`: For output printing. 51 | * `metrics.py`: Method for calculating evaluation indicators. 52 | * `predifineGraph.py`: Predefined graph generation method. 53 | * `TrainInits.py`: Training initialization, including settings of optimizer, device, random seed, etc. 54 | * **model**: Includes the implementation of FlashST and all baseline models, along with the necessary code to support the framework's execution. The `args.py` script is utilized to generate the required prefabricated data and parameter configurations for different baselines. Additionally, the `SAVE` folder serves as the storage location for saving the pre-trained models. 55 | * **SAVE**: This folder serves as the storage location for saving the trained models, including pretrain, eval and ori. 56 | 57 | 58 | ``` 59 | │ README.md 60 | │ requirements.txt 61 | │ 62 | ├─conf 63 | │ ├─AGCRN 64 | │ ├─ASTGCN 65 | │ ├─FlashST 66 | │ │ │ config.conf 67 | │ │ │ Params_pretrain.py 68 | │ ├─GWN 69 | │ ├─MSDR 70 | │ ├─MTGNN 71 | │ ├─PDFormer 72 | │ ├─ST-WA 73 | │ ├─STFGNN 74 | │ ├─STGCN 75 | │ ├─STSGCN 76 | │ └─TGCN 77 | │ 78 | ├─data 79 | │ ├─CA_District5 80 | │ ├─chengdu_didi 81 | │ ├─NYC_BIKE 82 | │ ├─PEMS03 83 | │ ├─PEMS04 84 | │ ├─PEMS07 85 | │ ├─PEMS07M 86 | │ ├─PEMS08 87 | │ ├─PDFormer 88 | │ ├─STFGNN 89 | │ └─STGODE 90 | │ 91 | ├─lib 92 | │ │ data_process.py 93 | │ │ logger.py 94 | │ │ metrics.py 95 | │ │ predifineGraph.py 96 | │ │ TrainInits.py 97 | │ 98 | ├─model 99 | │ │ FlashST.py 100 | │ │ PromptNet.py 101 | │ │ Run.py 102 | │ │ Trainer.py 103 | │ │ 104 | │ ├─AGCRN 105 | │ ├─ASTGCN 106 | │ ├─DMSTGCN 107 | │ ├─GWN 108 | │ ├─MSDR 109 | │ ├─MTGNN 110 | │ ├─PDFormer 111 | │ ├─STFGNN 112 | │ ├─STGCN 113 | │ ├─STGODE 114 | │ ├─STSGCN 115 | │ ├─ST_WA 116 | │ └─TGCN 117 | │ 118 | └─SAVE 119 | └─pretrain 120 | ├─GWN 121 | │ GWN_P8437.pth 122 | │ 123 | ├─MTGNN 124 | │ MTGNN_P8437.pth 125 | │ 126 | ├─PDFormer 127 | │ PDFormer_P8437.pth 128 | │ 129 | └─STGCN 130 | STGCN_P8437.pth 131 | 132 | ``` 133 | 134 | --------- 135 | 136 | 137 | 138 | ### 2.Environment [Back to Top] 139 | The code can be run in the following environments, other version of required packages may also work. 140 | * python==3.9.12 141 | * numpy==1.23.1 142 | * pytorch==1.9.0 143 | * cudatoolkit==11.1.1 144 | 145 | Or you can install the required environment, which can be done by running the following commands: 146 | ``` 147 | # cteate new environmrnt 148 | conda create -n FlashST python=3.9.12 149 | 150 | # activate environmrnt 151 | conda activate FlashST 152 | 153 | # Torch with CUDA 11.1 154 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 155 | 156 | # Install required libraries 157 | pip install -r requirements.txt 158 | ``` 159 | 160 | --------- 161 | 162 | 163 | 164 | ### 3. Run the codes [Back to Top] 165 | 166 | * First, download "data" folder from [hugging face (data.zip)](https://huggingface.co/datasets/bjdwh/FlashST-DATA/tree/main) or [wisemodel (data.zip)](https://wisemodel.cn/datasets/BJDWH/FlashST-data/file), put it in the FlashST-main directory, unzip and then enter "model" folder: 167 | ``` 168 | cd model 169 | ``` 170 | * To test different models in various modes, you can execute the Run.py code. There are some examples: 171 | ``` 172 | # Evaluate the performance of MTGNN enhanced by FlashST on the PEMS07M dataset 173 | python Run.py -dataset_test PEMS07M -mode eval -model MTGNN 174 | 175 | # Evaluate the performance of STGCN enhanced by FlashST on the CA_District5 dataset 176 | python Run.py -dataset_test CA_District5 -mode eval -model STGCN 177 | 178 | # Evaluate the original performance of STGCN on the chengdu_didi dataset 179 | python Run.py -dataset_test chengdu_didi -mode ori -model STGCN 180 | 181 | # Pretrain from scratch with MTGNN model, checkpoint will be saved in FlashST-main/SAVE/pretrain/MTGNN(model name)/xxx.pth 182 | python Run.py -mode pretrain -model MTGNN 183 | ``` 184 | 185 | * Parameter setting instructions. The parameter settings consist of two parts: the pre-training model and the baseline model. To avoid any confusion arising from potential overlapping parameter names, we employ a hyphen (-) to specify the parameters of FlashST and use a double hyphen (--) to specify the parameters of the baseline model. Here is an example: 186 | ``` 187 | # Set first_layer_embedding_size and out_layer_dim to 32 in STFGNN 188 | python Run.py -model STFGNN -mode ori -dataset_test PEMS08 --first_layer_embedding_size 32 --out_layer_dim 32 189 | ``` 190 | 191 | --------- 192 | 193 | 194 | ## Citation 195 | 196 | If you find FlashST useful in your research or applications, please kindly cite: 197 | 198 | ``` 199 | @misc{li2024flashst, 200 | title={FlashST: A Simple and Universal Prompt-Tuning Framework for Traffic Prediction}, 201 | author={Zhonghang Li and Lianghao Xia and Yong Xu and Chao Huang}, 202 | year={2024}, 203 | eprint={2405.17898}, 204 | archivePrefix={arXiv}, 205 | primaryClass={cs.LG} 206 | } 207 | ``` 208 | --------- 209 | 210 | 211 | ## Acknowledgements 212 | We developed our code framework drawing inspiration from [AGCRN](https://github.com/LeiBAI/AGCRN) and [GPT-ST](https://github.com/HKUDS/GPT-ST). Furthermore, the implementation of the baselines primarily relies on a combination of the code released by the original author and the code from [LibCity](https://github.com/LibCity/Bigscity-LibCity). We extend our heartfelt gratitude for their remarkable contribution. 213 | -------------------------------------------------------------------------------- /SAVE/pretrain/GWN/GWN_P8437.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/FlashST/39d4b8e445fde7ed5272151f7e74752d2c84c58e/SAVE/pretrain/GWN/GWN_P8437.pth -------------------------------------------------------------------------------- /SAVE/pretrain/MTGNN/MTGNN_P8437.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/FlashST/39d4b8e445fde7ed5272151f7e74752d2c84c58e/SAVE/pretrain/MTGNN/MTGNN_P8437.pth -------------------------------------------------------------------------------- /SAVE/pretrain/PDFormer/PDFormer_P8437.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/FlashST/39d4b8e445fde7ed5272151f7e74752d2c84c58e/SAVE/pretrain/PDFormer/PDFormer_P8437.pth -------------------------------------------------------------------------------- /SAVE/pretrain/STGCN/STGCN_P8437.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUDS/FlashST/39d4b8e445fde7ed5272151f7e74752d2c84c58e/SAVE/pretrain/STGCN/STGCN_P8437.pth -------------------------------------------------------------------------------- /conf/AGCRN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 211 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.1 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 2 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | 20 | [train] 21 | loss_func = mae 22 | seed = 10 23 | batch_size = 64 24 | epochs = 100 25 | lr_init = 0.003 26 | lr_decay = False 27 | lr_decay_rate = 0.3 28 | lr_decay_step = 5,20,40,70 29 | early_stop = True 30 | early_stop_patience = 15 31 | grad_norm = False 32 | max_grad_norm = 5 33 | real_value = True 34 | 35 | [test] 36 | mae_thresh = None 37 | mape_thresh = 0. 38 | 39 | [log] 40 | log_step = 20 41 | plot = False -------------------------------------------------------------------------------- /conf/AGCRN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 250 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.2 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 2 14 | output_dim = 2 15 | embed_dim = 2 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | 20 | [train] 21 | loss_func = mae 22 | seed = 12 23 | batch_size = 64 24 | epochs = 100 25 | 26 | -------------------------------------------------------------------------------- /conf/AGCRN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 228 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.1 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 2 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | 20 | [train] 21 | loss_func = mae 22 | seed = 10 23 | batch_size = 64 24 | epochs = 100 25 | lr_init = 0.003 26 | lr_decay = False 27 | lr_decay_rate = 0.3 28 | lr_decay_step = 5,20,40,70 29 | early_stop = True 30 | early_stop_patience = 15 31 | grad_norm = False 32 | max_grad_norm = 5 33 | real_value = True 34 | 35 | [test] 36 | mae_thresh = None 37 | mape_thresh = 0. 38 | 39 | [log] 40 | log_step = 20 41 | plot = False -------------------------------------------------------------------------------- /conf/AGCRN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 524 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.1 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 2 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | 20 | [train] 21 | loss_func = mae 22 | seed = 10 23 | batch_size = 64 24 | epochs = 100 25 | lr_init = 0.003 26 | lr_decay = False 27 | lr_decay_rate = 0.3 28 | lr_decay_step = 5,20,40,70 29 | early_stop = True 30 | early_stop_patience = 15 31 | grad_norm = False 32 | max_grad_norm = 5 33 | real_value = True 34 | 35 | [test] 36 | mae_thresh = None 37 | mape_thresh = 0. 38 | 39 | [log] 40 | log_step = 20 41 | plot = False -------------------------------------------------------------------------------- /conf/ASTGCN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 211 3 | len_input = 12 4 | num_for_predict = 12 5 | 6 | [model] 7 | nb_block = 2 8 | K = 3 9 | nb_chev_filter = 64 10 | nb_time_filter = 64 11 | time_strides = 1 -------------------------------------------------------------------------------- /conf/ASTGCN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 250 3 | len_input = 12 4 | num_for_predict = 12 5 | 6 | [model] 7 | nb_block = 2 8 | K = 3 9 | nb_chev_filter = 64 10 | nb_time_filter = 64 11 | time_strides = 1 -------------------------------------------------------------------------------- /conf/ASTGCN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 228 3 | len_input = 12 4 | num_for_predict = 12 5 | 6 | [model] 7 | nb_block = 2 8 | K = 3 9 | nb_chev_filter = 64 10 | nb_time_filter = 64 11 | time_strides = 1 -------------------------------------------------------------------------------- /conf/ASTGCN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 524 3 | len_input = 12 4 | num_for_predict = 12 5 | 6 | [model] 7 | nb_block = 2 8 | K = 3 9 | nb_chev_filter = 64 10 | nb_time_filter = 64 11 | time_strides = 1 -------------------------------------------------------------------------------- /conf/FlashST/Params_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | # import numpy as np 3 | import configparser 4 | # import pandas as pd 5 | 6 | def parse_args(device): 7 | # parser 8 | args = argparse.ArgumentParser(prefix_chars='-', description='pretrain_arguments') 9 | args_get, _ = args.parse_known_args() 10 | # get configuration 11 | config_file = '../conf/FlashST/config.conf' 12 | config = configparser.ConfigParser() 13 | config.read(config_file) 14 | 15 | # parser 16 | args = argparse.ArgumentParser(prefix_chars='-', description='arguments') 17 | 18 | args.add_argument('-cuda', default=True, type=bool) 19 | args.add_argument('-device', default=device, type=str, help='indices of GPUs') 20 | args.add_argument('-mode', default='ori', type=str, required=True) 21 | args.add_argument('-model', default='STGCN', type=str) 22 | args.add_argument('-dataset_test', default='PEMS07M', type=str) 23 | args.add_argument('-dataset_use', default=config['data']['dataset_use'].split(',')) 24 | 25 | # data 26 | args.add_argument('-his', default=config['data']['his'], type=int) 27 | args.add_argument('-pred', default=config['data']['pred'], type=int) 28 | args.add_argument('-val_ratio', default=config['data']['val_ratio'], type=float) 29 | args.add_argument('-test_ratio', default=config['data']['test_ratio'], type=float) 30 | args.add_argument('-tod', default=config['data']['tod'], type=eval) 31 | args.add_argument('-normalizer', default=config['data']['normalizer'], type=str) 32 | args.add_argument('-column_wise', default=config['data']['column_wise'], type=eval) 33 | args.add_argument('-default_graph', default=config['data']['default_graph'], type=eval) 34 | # model 35 | args.add_argument('-input_base_dim', default=config['model']['input_base_dim'], type=int) 36 | args.add_argument('-input_extra_dim', default=config['model']['input_extra_dim'], type=int) 37 | args.add_argument('-output_dim', default=config['model']['output_dim'], type=int) 38 | args.add_argument('-node_dim', default=config['model']['node_dim'], type=int) 39 | args.add_argument('-embed_dim', default=config['model']['embed_dim'], type=int) 40 | args.add_argument('-num_layer', default=config['model']['num_layer'], type=int) 41 | args.add_argument('-temp_dim_tid', default=config['model']['temp_dim_tid'], type=int) 42 | args.add_argument('-temp_dim_diw', default=config['model']['temp_dim_diw'], type=int) 43 | args.add_argument('-use_lpls', default=config['model']['use_lpls'], type=eval) 44 | args.add_argument('-if_time_in_day', default=config['model']['if_time_in_day'], type=eval) 45 | args.add_argument('-if_day_in_week', default=config['model']['if_day_in_week'], type=eval) 46 | args.add_argument('-if_spatial', default=config['model']['if_spatial'], type=eval) 47 | # train 48 | # args.add_argument('-mode', default=config['train']['mode'], type=str) 49 | args.add_argument('-loss_func', default=config['train']['loss_func'], type=str) 50 | args.add_argument('-seed', default=config['train']['seed'], type=int) 51 | args.add_argument('-batch_size', default=config['train']['batch_size'], type=int) 52 | args.add_argument('-lr_init', default=config['train']['lr_init'], type=float) 53 | args.add_argument('-lr_decay', default=config['train']['lr_decay'], type=eval) 54 | args.add_argument('-lr_decay_rate', default=config['train']['lr_decay_rate'], type=float) 55 | args.add_argument('-lr_decay_step', default=config['train']['lr_decay_step'], type=str) 56 | args.add_argument('-early_stop', default=config['train']['early_stop'], type=eval) 57 | args.add_argument('-early_stop_patience', default=config['train']['early_stop_patience'], type=int) 58 | args.add_argument('-grad_norm', default=config['train']['grad_norm'], type=eval) 59 | args.add_argument('-max_grad_norm', default=config['train']['max_grad_norm'], type=int) 60 | args.add_argument('-real_value', default=config['train']['real_value'], type=eval, 61 | help='use real value for loss calculation') 62 | args.add_argument('-pretrain_epochs', default=config['train']['pretrain_epochs'], type=int) 63 | args.add_argument('-eval_epochs', default=config['train']['eval_epochs'], type=int) 64 | args.add_argument('-ori_epochs', default=config['train']['ori_epochs'], type=int) 65 | args.add_argument('-load_pretrain_path', default=config['train']['load_pretrain_path'], type=str) 66 | args.add_argument('-save_pretrain_path', default=config['train']['save_pretrain_path'], type=str) 67 | args.add_argument('-debug', default=config['train']['debug'], type=str) 68 | # test 69 | args.add_argument('-mae_thresh', default=config['test']['mae_thresh'], type=eval) 70 | args.add_argument('-mape_thresh', default=config['test']['mape_thresh'], type=float) 71 | # log 72 | args.add_argument('-log_dir', default='./', type=str) 73 | args.add_argument('-log_step', default=config['log']['log_step'], type=int) 74 | args.add_argument('-plot', default=config['log']['plot'], type=eval) 75 | args, _ = args.parse_known_args() 76 | return args -------------------------------------------------------------------------------- /conf/FlashST/config.conf: -------------------------------------------------------------------------------- 1 | [new_para] 2 | # dataset_use = ['PEMS08', 'PEMS04', 'PEMS07', 'PEMS03'] 3 | [data] 4 | dataset_use = PEMS08,PEMS04,PEMS07,PEMS03 5 | # dataset_use = PEMS08,PEMS04 6 | dataset_test = PEMS07M # NYC_BIKE, CA_District5, PEMS07M, chengdu_didi 7 | his = 12 8 | pred = 12 9 | val_ratio = 0.2 10 | test_ratio = 0.2 11 | tod = False 12 | normalizer = std 13 | column_wise = False 14 | default_graph = True 15 | 16 | [model] 17 | input_base_dim = 1 18 | input_extra_dim = 2 19 | output_dim = 1 20 | use_lpls = False 21 | node_dim = 32 22 | embed_dim = 32 23 | num_layer = 3 24 | temp_dim_tid = 32 25 | temp_dim_diw = 32 26 | if_time_in_day = True 27 | if_day_in_week = True 28 | if_spatial = True 29 | 30 | 31 | [train] 32 | loss_func = mask_mae 33 | seed = 0 34 | batch_size = 64 35 | lr_init = 0.003 36 | lr_decay = True 37 | lr_decay_rate = 0.3 38 | lr_decay_step = 70, 160, 240 39 | early_stop = True 40 | early_stop_patience = 25 41 | grad_norm = True 42 | max_grad_norm = 5 43 | real_value = False 44 | pretrain_epochs = 300 45 | eval_epochs = 20 46 | ori_epochs = 100 47 | load_pretrain_path = GWN_P8437.pth 48 | save_pretrain_path = P8437_stgcn.pth 49 | debug = True 50 | 51 | [test] 52 | mae_thresh = 0. 53 | mape_thresh = 0.001 54 | 55 | 56 | [log] 57 | log_step = 20 58 | plot = False -------------------------------------------------------------------------------- /conf/GWN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 211 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/GWN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 250 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/GWN/PEMS03.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 358 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/GWN/PEMS04.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 307 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/GWN/PEMS07.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 883 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/GWN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 228 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/GWN/PEMS08.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 170 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/GWN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 524 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | dropout = 0.3 12 | blocks = 4 13 | layers = 2 14 | gcn_bool = True 15 | addaptadj = True 16 | adjtype = doubletransition 17 | randomadj = True 18 | aptonly = True 19 | kernel_size = 2 20 | nhid = 32 21 | residual_channels = 32 22 | dilation_channels = 32 -------------------------------------------------------------------------------- /conf/MSDR/CA_District5.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | # batch_size = 64 4 | # data = ./data/processed/PEMS08/ 5 | # sensors_distance = ./data/PEMS08/PEMS08.csv 6 | # column_wise = False 7 | # normalizer = std 8 | 9 | [model] 10 | cl_decay_steps = 2000 11 | filter_type = dual_random_walk 12 | horizon = 12 13 | input_dim = 1 14 | max_diffusion_step = 1 15 | num_nodes = 211 16 | num_rnn_layers = 2 17 | output_dim = 1 18 | rnn_units = 64 19 | seq_len = 12 20 | pre_k = 4 21 | pre_v = 1 22 | use_curriculum_learning = True 23 | construct_type = connectivity 24 | l2lambda = 0 25 | 26 | [train] 27 | # base_lr = 0.0015 28 | dropout = 0 29 | # epoch = 0 30 | # epochs = 250 31 | # epsilon = 1.0e-3 32 | # global_step = 0 33 | # lr_decay_ratio = 0.2 34 | # max_grad_norm = 5 35 | # max_to_keep = 100 36 | # min_learning_rate = 2.0e-06 37 | # optimizer = adam 38 | # patience = 50 39 | # steps = [30, 50, 70, 80] 40 | # test_every_n_epochs = 10 -------------------------------------------------------------------------------- /conf/MSDR/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | # batch_size = 64 4 | # data = ./data/processed/PEMS08/ 5 | # sensors_distance = ./data/PEMS08/PEMS08.csv 6 | # column_wise = False 7 | # normalizer = std 8 | 9 | [model] 10 | cl_decay_steps = 2000 11 | filter_type = dual_random_walk 12 | horizon = 12 13 | input_dim = 1 14 | max_diffusion_step = 1 15 | num_nodes = 250 16 | num_rnn_layers = 2 17 | output_dim = 1 18 | rnn_units = 64 19 | seq_len = 12 20 | pre_k = 4 21 | pre_v = 1 22 | use_curriculum_learning = True 23 | construct_type = connectivity 24 | l2lambda = 0 25 | 26 | [train] 27 | # base_lr = 0.0015 28 | dropout = 0 29 | # epoch = 0 30 | # epochs = 250 31 | # epsilon = 1.0e-3 32 | # global_step = 0 33 | # lr_decay_ratio = 0.2 34 | # max_grad_norm = 5 35 | # max_to_keep = 100 36 | # min_learning_rate = 2.0e-06 37 | # optimizer = adam 38 | # patience = 50 39 | # steps = [30, 50, 70, 80] 40 | # test_every_n_epochs = 10 -------------------------------------------------------------------------------- /conf/MSDR/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | # batch_size = 64 4 | # data = ./data/processed/PEMS08/ 5 | # sensors_distance = ./data/PEMS08/PEMS08.csv 6 | # column_wise = False 7 | # normalizer = std 8 | 9 | [model] 10 | cl_decay_steps = 2000 11 | filter_type = dual_random_walk 12 | horizon = 12 13 | input_dim = 1 14 | max_diffusion_step = 1 15 | num_nodes = 228 16 | num_rnn_layers = 2 17 | output_dim = 1 18 | rnn_units = 64 19 | seq_len = 12 20 | pre_k = 4 21 | pre_v = 1 22 | use_curriculum_learning = True 23 | construct_type = connectivity 24 | l2lambda = 0 25 | 26 | [train] 27 | # base_lr = 0.0015 28 | dropout = 0 29 | # epoch = 0 30 | # epochs = 250 31 | # epsilon = 1.0e-3 32 | # global_step = 0 33 | # lr_decay_ratio = 0.2 34 | # max_grad_norm = 5 35 | # max_to_keep = 100 36 | # min_learning_rate = 2.0e-06 37 | # optimizer = adam 38 | # patience = 50 39 | # steps = [30, 50, 70, 80] 40 | # test_every_n_epochs = 10 -------------------------------------------------------------------------------- /conf/MSDR/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | # batch_size = 64 4 | # data = ./data/processed/PEMS08/ 5 | # sensors_distance = ./data/PEMS08/PEMS08.csv 6 | # column_wise = False 7 | # normalizer = std 8 | 9 | [model] 10 | cl_decay_steps = 2000 11 | filter_type = dual_random_walk 12 | horizon = 12 13 | input_dim = 1 14 | max_diffusion_step = 1 15 | num_nodes = 524 16 | num_rnn_layers = 2 17 | output_dim = 1 18 | rnn_units = 64 19 | seq_len = 12 20 | pre_k = 4 21 | pre_v = 1 22 | use_curriculum_learning = True 23 | construct_type = connectivity 24 | l2lambda = 0 25 | 26 | [train] 27 | # base_lr = 0.0015 28 | dropout = 0 29 | # epoch = 0 30 | # epochs = 250 31 | # epsilon = 1.0e-3 32 | # global_step = 0 33 | # lr_decay_ratio = 0.2 34 | # max_grad_norm = 5 35 | # max_to_keep = 100 36 | # min_learning_rate = 2.0e-06 37 | # optimizer = adam 38 | # patience = 50 39 | # steps = [30, 50, 70, 80] 40 | # test_every_n_epochs = 10 -------------------------------------------------------------------------------- /conf/MTGNN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 211 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/MTGNN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 250 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 2 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/MTGNN/PEMS03.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 358 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/MTGNN/PEMS04.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 307 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/MTGNN/PEMS07.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 883 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/MTGNN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 228 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/MTGNN/PEMS08.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 170 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/MTGNN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 524 6 | input_window = 12 7 | output_window = 12 8 | output_dim = 1 9 | 10 | [model] 11 | gcn_true = True 12 | buildA_true = True 13 | gcn_depth = 2 14 | dropout = 0.3 15 | subgraph_size = 20 16 | node_dim = 40 17 | dilation_exponential = 1 18 | conv_channels = 32 19 | residual_channels = 32 20 | skip_channels = 64 21 | end_channels = 128 22 | layers = 3 23 | propalpha = 0.05 24 | tanhalpha = 3 25 | layer_norm_affline = True 26 | use_curriculum_learning = True 27 | step_size1 = 2500 28 | task_level = 0 29 | num_split = 1 30 | step_size2 = 100 -------------------------------------------------------------------------------- /conf/PDFormer/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=300 23 | cand_key_days=14 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/PDFormer/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=1800 23 | cand_key_days=14 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/PDFormer/PEMS03.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=300 23 | cand_key_days=14 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/PDFormer/PEMS04.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=300 23 | cand_key_days=14 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/PDFormer/PEMS07.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=300 23 | cand_key_days=14 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/PDFormer/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=300 23 | cand_key_days=14 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/PDFormer/PEMS08.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=300 23 | cand_key_days=21 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/PDFormer/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [model] 2 | embed_dim = 64 3 | skip_dim = 256 4 | lape_dim = 32 5 | geo_num_heads = 4 6 | sem_num_heads = 2 7 | t_num_heads = 2 8 | mlp_ratio = 4 9 | qkv_bias = True 10 | drop = 0 11 | attn_drop = 0 12 | drop_path = 0.3 13 | s_attn_size = 3 14 | t_attn_size = 1 15 | enc_depth = 6 16 | type_ln = post 17 | type_short_path = hop 18 | add_time_in_day = False 19 | add_day_in_week = False 20 | far_mask_delta=5 21 | dtw_delta=5 22 | time_intervals=600 23 | cand_key_days=14 24 | n_cluster=16 25 | cluster_max_iter=5 26 | cluster_method=kshape -------------------------------------------------------------------------------- /conf/ST-WA/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 211 6 | lag = 12 7 | horizon = 12 8 | val_ratio = 0.2 9 | test_ratio = 0.2 10 | # adj_filename = ../data/PEMS08/PEMS08.csv 11 | id_filename = None 12 | 13 | [model] 14 | in_dim = 1 15 | out_dim = 1 16 | channels = 16 17 | dynamic = True 18 | memory_size = 16 19 | 20 | [train] 21 | # seed = 0 22 | # learning_rate = 0.001 23 | # batch_size = 16 24 | # epochs = 100 25 | # grad_norm = False 26 | # max_grad_norm = 5 27 | # save = ./garage/metr 28 | # expid = 1 29 | # log_step = 20 30 | # early_stop_patience = 15 31 | 32 | [test] 33 | mae_thresh = 0.0 34 | mape_thresh = 0.0 35 | -------------------------------------------------------------------------------- /conf/ST-WA/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 250 6 | lag = 12 7 | horizon = 12 8 | val_ratio = 0.2 9 | test_ratio = 0.2 10 | # adj_filename = ../data/PEMS08/PEMS08.csv 11 | id_filename = None 12 | 13 | [model] 14 | in_dim = 1 15 | out_dim = 1 16 | channels = 16 17 | dynamic = True 18 | memory_size = 16 19 | 20 | [train] 21 | # seed = 10 22 | # learning_rate = 0.001 23 | # batch_size = 16 24 | # epochs = 100 25 | # grad_norm = False 26 | # max_grad_norm = 5 27 | # save = ./garage/metr 28 | # expid = 1 29 | # log_step = 20 30 | # early_stop_patience = 15 31 | 32 | [test] 33 | mae_thresh = 0.0 34 | mape_thresh = 0.0 35 | -------------------------------------------------------------------------------- /conf/ST-WA/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 228 6 | lag = 12 7 | horizon = 12 8 | val_ratio = 0.2 9 | test_ratio = 0.2 10 | # adj_filename = ../data/PEMS08/PEMS08.csv 11 | id_filename = None 12 | 13 | [model] 14 | in_dim = 1 15 | out_dim = 1 16 | channels = 16 17 | dynamic = True 18 | memory_size = 16 19 | 20 | [train] 21 | # seed = 10 22 | # learning_rate = 0.001 23 | # batch_size = 16 24 | # epochs = 100 25 | # grad_norm = False 26 | # max_grad_norm = 5 27 | # save = ./garage/metr 28 | # expid = 1 29 | # log_step = 20 30 | # early_stop_patience = 15 31 | 32 | [test] 33 | mae_thresh = 0.0 34 | mape_thresh = 0.0 35 | -------------------------------------------------------------------------------- /conf/ST-WA/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 524 6 | lag = 12 7 | horizon = 12 8 | val_ratio = 0.2 9 | test_ratio = 0.2 10 | # adj_filename = ../data/PEMS08/PEMS08.csv 11 | id_filename = None 12 | 13 | [model] 14 | in_dim = 1 15 | out_dim = 1 16 | channels = 16 17 | dynamic = True 18 | memory_size = 16 19 | 20 | [train] 21 | # seed = 0 22 | # learning_rate = 0.001 23 | # batch_size = 16 24 | # epochs = 100 25 | # grad_norm = False 26 | # max_grad_norm = 5 27 | # save = ./garage/metr 28 | # expid = 1 29 | # log_step = 20 30 | # early_stop_patience = 15 31 | 32 | [test] 33 | mae_thresh = 0.0 34 | mape_thresh = 0.0 35 | -------------------------------------------------------------------------------- /conf/STFGNN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | num_nodes = 211 4 | window = 12 5 | lag = 12 6 | horizon = 12 7 | order = 1 8 | period = 288 9 | sparsity = 0.01 10 | 11 | [model] 12 | hidden_dims = [[64, 64, 64], [64, 64, 64], [64, 64, 64]] 13 | first_layer_embedding_size = 64 14 | out_layer_dim = 128 15 | output_dim = 1 16 | strides = 4 17 | temporal_emb = True 18 | spatial_emb = True 19 | use_mask = False 20 | activation = GLU 21 | module_type = individual 22 | 23 | -------------------------------------------------------------------------------- /conf/STFGNN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | num_nodes = 250 4 | window = 12 5 | lag = 12 6 | horizon = 12 7 | order = 1 8 | period = 48 9 | sparsity = 0.01 10 | 11 | [model] 12 | hidden_dims = [[64, 64, 64], [64, 64, 64], [64, 64, 64]] 13 | first_layer_embedding_size = 64 14 | out_layer_dim = 128 15 | output_dim = 1 16 | strides = 4 17 | temporal_emb = True 18 | spatial_emb = True 19 | use_mask = False 20 | activation = GLU 21 | module_type = individual 22 | 23 | -------------------------------------------------------------------------------- /conf/STFGNN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | num_nodes = 228 4 | window = 12 5 | lag = 12 6 | horizon = 12 7 | order = 1 8 | period = 288 9 | sparsity = 0.01 10 | 11 | [model] 12 | hidden_dims = [[64, 64, 64], [64, 64, 64], [64, 64, 64]] 13 | first_layer_embedding_size = 64 14 | out_layer_dim = 128 15 | output_dim = 1 16 | strides = 4 17 | temporal_emb = True 18 | spatial_emb = True 19 | use_mask = False 20 | activation = GLU 21 | module_type = individual 22 | 23 | -------------------------------------------------------------------------------- /conf/STFGNN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | num_nodes = 524 4 | window = 12 5 | lag = 12 6 | horizon = 12 7 | order = 1 8 | period = 288 9 | sparsity = 0.01 10 | 11 | [model] 12 | hidden_dims = [[64, 64, 64], [64, 64, 64], [64, 64, 64]] 13 | first_layer_embedding_size = 64 14 | out_layer_dim = 128 15 | output_dim = 1 16 | strides = 4 17 | temporal_emb = True 18 | spatial_emb = True 19 | use_mask = False 20 | activation = GLU 21 | module_type = individual 22 | 23 | -------------------------------------------------------------------------------- /conf/STGCN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 211 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | 13 | -------------------------------------------------------------------------------- /conf/STGCN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 250 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | -------------------------------------------------------------------------------- /conf/STGCN/PEMS03.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 358 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | -------------------------------------------------------------------------------- /conf/STGCN/PEMS04.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 307 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | 13 | -------------------------------------------------------------------------------- /conf/STGCN/PEMS07.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 883 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | 13 | -------------------------------------------------------------------------------- /conf/STGCN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 228 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | -------------------------------------------------------------------------------- /conf/STGCN/PEMS08.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 170 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | -------------------------------------------------------------------------------- /conf/STGCN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 266 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | Ks = 3 8 | Kt = 3 9 | blocks1 = [64, 32, 128] 10 | drop_prob = 0 11 | outputl_ks = 3 12 | 13 | -------------------------------------------------------------------------------- /conf/STSGCN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 211 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | filter_list = [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] 8 | rho = 1 9 | feature_dim = 64 10 | module_type = individual 11 | activation = GLU 12 | temporal_emb = True 13 | spatial_emb = True 14 | use_mask = False 15 | steps = 3 16 | first_layer_embedding_size = 64 -------------------------------------------------------------------------------- /conf/STSGCN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 250 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | filter_list = [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] 8 | rho = 1 9 | feature_dim = 64 10 | module_type = individual 11 | activation = GLU 12 | temporal_emb = True 13 | spatial_emb = True 14 | use_mask = False 15 | steps = 3 16 | first_layer_embedding_size = 64 -------------------------------------------------------------------------------- /conf/STSGCN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 228 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | filter_list = [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] 8 | rho = 1 9 | feature_dim = 64 10 | module_type = individual 11 | activation = GLU 12 | temporal_emb = True 13 | spatial_emb = True 14 | use_mask = False 15 | steps = 3 16 | first_layer_embedding_size = 64 -------------------------------------------------------------------------------- /conf/STSGCN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 524 3 | input_window = 12 4 | output_window = 12 5 | 6 | [model] 7 | filter_list = [[64, 64, 64], [64, 64, 64], [64, 64, 64], [64, 64, 64]] 8 | rho = 1 9 | feature_dim = 64 10 | module_type = individual 11 | activation = GLU 12 | temporal_emb = True 13 | spatial_emb = True 14 | use_mask = False 15 | steps = 3 16 | first_layer_embedding_size = 64 -------------------------------------------------------------------------------- /conf/TGCN/CA_District5.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | num_nodes = 211 4 | input_window = 12 5 | output_window = 12 6 | 7 | [model] 8 | rnn_units = 100 9 | lam = 0.0015 10 | output_dim = 1 11 | 12 | -------------------------------------------------------------------------------- /conf/TGCN/NYC_BIKE.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | device = cuda:0 3 | 4 | [data] 5 | num_nodes = 250 6 | input_window = 12 7 | output_window = 12 8 | 9 | [model] 10 | rnn_units = 100 11 | lam = 0.0015 12 | output_dim = 2 13 | 14 | -------------------------------------------------------------------------------- /conf/TGCN/PEMS07M.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | num_nodes = 228 4 | input_window = 12 5 | output_window = 12 6 | 7 | [model] 8 | rnn_units = 100 9 | lam = 0.0015 10 | output_dim = 1 11 | 12 | -------------------------------------------------------------------------------- /conf/TGCN/chengdu_didi.conf: -------------------------------------------------------------------------------- 1 | 2 | [data] 3 | num_nodes = 524 4 | input_window = 12 5 | output_window = 12 6 | 7 | [model] 8 | rnn_units = 100 9 | lam = 0.0015 10 | output_dim = 1 11 | 12 | -------------------------------------------------------------------------------- /lib/TrainInits.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | def init_seed(seed, seed_mode): 6 | ''' 7 | Disable cudnn to maximize reproducibility 8 | ''' 9 | if seed_mode: 10 | torch.cuda.cudnn_enabled = False 11 | torch.backends.cudnn.deterministic = True 12 | # random.seed(seed) 13 | # np.random.seed(seed) 14 | # torch.manual_seed(seed) 15 | # torch.cuda.manual_seed(seed) 16 | 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | random.seed(seed) 21 | 22 | def init_device(opt): 23 | if torch.cuda.is_available(): 24 | opt.cuda = True 25 | torch.cuda.set_device(int(opt.device[5])) 26 | else: 27 | opt.cuda = False 28 | opt.device = 'cpu' 29 | return opt 30 | 31 | def init_optim(model, opt): 32 | ''' 33 | Initialize optimizer 34 | ''' 35 | return torch.optim.Adam(params=model.parameters(),lr=opt.lr_init) 36 | 37 | def init_lr_scheduler(optim, opt): 38 | ''' 39 | Initialize the learning rate scheduler 40 | ''' 41 | #return torch.optim.lr_scheduler.StepLR(optimizer=optim,gamma=opt.lr_scheduler_rate,step_size=opt.lr_scheduler_step) 42 | return torch.optim.lr_scheduler.MultiStepLR(optimizer=optim, milestones=opt.lr_decay_steps, 43 | gamma = opt.lr_scheduler_rate) 44 | 45 | def print_model_parameters(model, only_num = True): 46 | print('*****************Model Parameter*****************') 47 | if not only_num: 48 | for name, param in model.named_parameters(): 49 | print(name, param.shape, param.requires_grad) 50 | total_num = sum([param.nelement() for param in model.parameters()]) 51 | print('Total params num: {}'.format(total_num)) 52 | print('*****************Finish Parameter****************') 53 | 54 | def get_memory_usage(device): 55 | allocated_memory = torch.cuda.memory_allocated(device) / (1024*1024.) 56 | cached_memory = torch.cuda.memory_cached(device) / (1024*1024.) 57 | print('Allocated Memory: {:.2f} MB, Cached Memory: {:.2f} MB'.format(allocated_memory, cached_memory)) 58 | return allocated_memory, cached_memory -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import datetime 4 | 5 | def get_logger(root, name=None, debug=True): 6 | #when debug is true, show DEBUG and INFO in screen 7 | #when debug is false, show DEBUG in file and info in both screen&file 8 | #INFO will always be in screen 9 | # create a logger 10 | logger = logging.getLogger(name) 11 | #critical > error > warning > info > debug > notset 12 | logger.setLevel(logging.DEBUG) 13 | 14 | # define the formate 15 | formatter = logging.Formatter('%(asctime)s: %(message)s', "%Y-%m-%d %H:%M") 16 | # create another handler for output log to console 17 | console_handler = logging.StreamHandler() 18 | if debug: 19 | console_handler.setLevel(logging.DEBUG) 20 | else: 21 | console_handler.setLevel(logging.INFO) 22 | # create a handler for write log to file 23 | logfile = os.path.join(root, 'run.log') 24 | print('Creat Log File in: ', logfile) 25 | file_handler = logging.FileHandler(logfile, mode='w') 26 | file_handler.setLevel(logging.DEBUG) 27 | file_handler.setFormatter(formatter) 28 | console_handler.setFormatter(formatter) 29 | # add Handler to logger 30 | logger.addHandler(console_handler) 31 | if not debug: 32 | logger.addHandler(file_handler) 33 | return logger 34 | 35 | 36 | if __name__ == '__main__': 37 | time = datetime.now().strftime('%Y%m%d%H%M%S') 38 | print(time) 39 | logger = get_logger('./log.txt', debug=True) 40 | logger.debug('this is a {} debug message'.format(1)) 41 | logger.info('this is an info message') 42 | logger.debug('this is a debug message') 43 | logger.info('this is an info message') 44 | logger.debug('this is a debug message') 45 | logger.info('this is an info message') -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Always evaluate the model with MAE, RMSE, MAPE, RRSE, PNBI, and oPNBI. 3 | Why add mask to MAE and RMSE? 4 | Filter the 0 that may be caused by error (such as loop sensor) 5 | Why add mask to MAPE and MARE? 6 | Ignore very small values (e.g., 0.5/0.5=100%) 7 | ''' 8 | import numpy as np 9 | import torch 10 | 11 | def MAE_torch(pred, true, mask_value=None): 12 | if mask_value != None: 13 | mask = torch.gt(true, mask_value) 14 | pred = torch.masked_select(pred, mask) 15 | true = torch.masked_select(true, mask) 16 | mae_loss = torch.abs(true - pred) 17 | # print(mae_loss[mae_loss>3].shape, mae_loss[mae_loss<1].shape, mae_loss.shape) 18 | return torch.mean(mae_loss), mae_loss 19 | 20 | def huber_loss(pred, true, mask_value=None, delta=1.0): 21 | if mask_value != None: 22 | mask = torch.gt(true, mask_value) 23 | pred = torch.masked_select(pred, mask) 24 | true = torch.masked_select(true, mask) 25 | residual = torch.abs(pred - true) 26 | condition = torch.le(residual, delta) 27 | small_res = 0.5 * torch.square(residual) 28 | large_res = delta * residual - 0.5 * delta * delta 29 | return torch.mean(torch.where(condition, small_res, large_res)), None 30 | # lo = torch.nn.SmoothL1Loss() 31 | # return lo(preds, labels) 32 | 33 | def MSE_torch(pred, true, mask_value=None): 34 | if mask_value != None: 35 | mask = torch.gt(true, mask_value) 36 | pred = torch.masked_select(pred, mask) 37 | true = torch.masked_select(true, mask) 38 | return torch.mean((pred - true) ** 2) 39 | 40 | def RMSE_torch(pred, true, mask_value=None): 41 | if mask_value != None: 42 | mask = torch.gt(true, mask_value) 43 | pred = torch.masked_select(pred, mask) 44 | true = torch.masked_select(true, mask) 45 | return torch.sqrt(torch.mean((pred - true) ** 2)) 46 | 47 | def RRSE_torch(pred, true, mask_value=None): 48 | if mask_value != None: 49 | mask = torch.gt(true, mask_value) 50 | pred = torch.masked_select(pred, mask) 51 | true = torch.masked_select(true, mask) 52 | return torch.sqrt(torch.sum((pred - true) ** 2)) / torch.sqrt(torch.sum((pred - true.mean()) ** 2)) 53 | 54 | def CORR_torch(pred, true, mask_value=None): 55 | #input B, T, N, D or B, N, D or B, N 56 | if len(pred.shape) == 2: 57 | pred = pred.unsqueeze(dim=1).unsqueeze(dim=1) 58 | true = true.unsqueeze(dim=1).unsqueeze(dim=1) 59 | elif len(pred.shape) == 3: 60 | pred = pred.transpose(1, 2).unsqueeze(dim=1) 61 | true = true.transpose(1, 2).unsqueeze(dim=1) 62 | elif len(pred.shape) == 4: 63 | #B, T, N, D -> B, T, D, N 64 | pred = pred.transpose(2, 3) 65 | true = true.transpose(2, 3) 66 | else: 67 | raise ValueError 68 | dims = (0, 1, 2) 69 | pred_mean = pred.mean(dim=dims) 70 | true_mean = true.mean(dim=dims) 71 | pred_std = pred.std(dim=dims) 72 | true_std = true.std(dim=dims) 73 | correlation = ((pred - pred_mean)*(true - true_mean)).mean(dim=dims) / (pred_std*true_std) 74 | index = (true_std != 0) 75 | correlation = (correlation[index]).mean() 76 | return correlation 77 | 78 | 79 | def MAPE_torch(pred, true, mask_value=None): 80 | if mask_value != None: 81 | mask = torch.gt(true, mask_value) 82 | pred = torch.masked_select(pred, mask) 83 | true = torch.masked_select(true, mask) 84 | # print(true[true<1].shape, true[true<0.0001].shape, true[true==0].shape) 85 | # print(true) 86 | return torch.mean(torch.abs(torch.div((true - pred), true))) 87 | 88 | def PNBI_torch(pred, true, mask_value=None): 89 | if mask_value != None: 90 | mask = torch.gt(true, mask_value) 91 | pred = torch.masked_select(pred, mask) 92 | true = torch.masked_select(true, mask) 93 | indicator = torch.gt(pred - true, 0).float() 94 | return indicator.mean() 95 | 96 | def oPNBI_torch(pred, true, mask_value=None): 97 | if mask_value != None: 98 | mask = torch.gt(true, mask_value) 99 | pred = torch.masked_select(pred, mask) 100 | true = torch.masked_select(true, mask) 101 | bias = (true+pred) / (2*true) 102 | return bias.mean() 103 | 104 | def MARE_torch(pred, true, mask_value=None): 105 | if mask_value != None: 106 | mask = torch.gt(true, mask_value) 107 | pred = torch.masked_select(pred, mask) 108 | true = torch.masked_select(true, mask) 109 | return torch.div(torch.sum(torch.abs((true - pred))), torch.sum(true)) 110 | 111 | def SMAPE_torch(pred, true, mask_value=None): 112 | if mask_value != None: 113 | mask = torch.gt(true, mask_value) 114 | pred = torch.masked_select(pred, mask) 115 | true = torch.masked_select(true, mask) 116 | return torch.mean(torch.abs(true-pred)/(torch.abs(true)+torch.abs(pred))) 117 | 118 | 119 | def MAE_np(pred, true, mask_value=None): 120 | if mask_value != None: 121 | mask = np.where(true > (mask_value), True, False) 122 | true = true[mask] 123 | pred = pred[mask] 124 | MAE = np.mean(np.absolute(pred-true)) 125 | return MAE 126 | 127 | def RMSE_np(pred, true, mask_value=None): 128 | if mask_value != None: 129 | mask = np.where(true > (mask_value), True, False) 130 | true = true[mask] 131 | pred = pred[mask] 132 | RMSE = np.sqrt(np.mean(np.square(pred-true))) 133 | return RMSE 134 | 135 | #Root Relative Squared Error 136 | def RRSE_np(pred, true, mask_value=None): 137 | if mask_value != None: 138 | mask = np.where(true > (mask_value), True, False) 139 | true = true[mask] 140 | pred = pred[mask] 141 | mean = true.mean() 142 | return np.divide(np.sqrt(np.sum((pred-true) ** 2)), np.sqrt(np.sum((true-mean) ** 2))) 143 | 144 | def MAPE_np(pred, true, mask_value=None): 145 | if mask_value != None: 146 | mask = np.where(true > (mask_value), True, False) 147 | true = true[mask] 148 | pred = pred[mask] 149 | return np.mean(np.absolute(np.divide((true - pred), true))) 150 | 151 | def PNBI_np(pred, true, mask_value=None): 152 | #if PNBI=0, all pred are smaller than true 153 | #if PNBI=1, all pred are bigger than true 154 | if mask_value != None: 155 | mask = np.where(true > (mask_value), True, False) 156 | true = true[mask] 157 | pred = pred[mask] 158 | bias = pred-true 159 | indicator = np.where(bias>0, True, False) 160 | return indicator.mean() 161 | 162 | def oPNBI_np(pred, true, mask_value=None): 163 | #if oPNBI>1, pred are bigger than true 164 | #if oPNBI<1, pred are smaller than true 165 | #however, this metric is too sentive to small values. Not good! 166 | if mask_value != None: 167 | mask = np.where(true > (mask_value), True, False) 168 | true = true[mask] 169 | pred = pred[mask] 170 | bias = (true + pred) / (2 * true) 171 | return bias.mean() 172 | 173 | def MARE_np(pred, true, mask_value=None): 174 | if mask_value != None: 175 | mask = np.where(true> (mask_value), True, False) 176 | true = true[mask] 177 | pred = pred[mask] 178 | return np.divide(np.sum(np.absolute((true - pred))), np.sum(true)) 179 | 180 | def CORR_np(pred, true, mask_value=None): 181 | #input B, T, N, D or B, N, D or B, N 182 | if len(pred.shape) == 2: 183 | #B, N 184 | pred = pred.unsqueeze(dim=1).unsqueeze(dim=1) 185 | true = true.unsqueeze(dim=1).unsqueeze(dim=1) 186 | elif len(pred.shape) == 3: 187 | #np.transpose include permute, B, T, N 188 | pred = np.expand_dims(pred.transpose(0, 2, 1), axis=1) 189 | true = np.expand_dims(true.transpose(0, 2, 1), axis=1) 190 | elif len(pred.shape) == 4: 191 | #B, T, N, D -> B, T, D, N 192 | pred = pred.transpose(0, 1, 2, 3) 193 | true = true.transpose(0, 1, 2, 3) 194 | else: 195 | raise ValueError 196 | dims = (0, 1, 2) 197 | pred_mean = pred.mean(axis=dims) 198 | true_mean = true.mean(axis=dims) 199 | pred_std = pred.std(axis=dims) 200 | true_std = true.std(axis=dims) 201 | correlation = ((pred - pred_mean)*(true - true_mean)).mean(axis=dims) / (pred_std*true_std) 202 | index = (true_std != 0) 203 | correlation = (correlation[index]).mean() 204 | return correlation 205 | 206 | def All_Metrics(pred, true, mask1, mask2): 207 | #mask1 filter the very small value, mask2 filter the value lower than a defined threshold 208 | assert type(pred) == type(true) 209 | if type(pred) == np.ndarray: 210 | mae = MAE_np(pred, true, mask1) 211 | rmse = RMSE_np(pred, true, mask1) 212 | mape = MAPE_np(pred, true, mask2) 213 | rrse = RRSE_np(pred, true, mask1) 214 | # corr = 0 215 | corr = CORR_np(pred, true, mask1) 216 | #pnbi = PNBI_np(pred, true, mask1) 217 | #opnbi = oPNBI_np(pred, true, mask2) 218 | elif type(pred) == torch.Tensor: 219 | mae, _ = MAE_torch(pred, true, mask1) 220 | rmse = RMSE_torch(pred, true, mask1) 221 | mape = MAPE_torch(pred, true, mask2) 222 | rrse = RRSE_torch(pred, true, mask1) 223 | corr = CORR_torch(pred, true, mask1) 224 | #pnbi = PNBI_torch(pred, true, mask1) 225 | #opnbi = oPNBI_torch(pred, true, mask2) 226 | else: 227 | raise TypeError 228 | return mae, rmse, mape, rrse, corr 229 | 230 | def SIGIR_Metrics(pred, true, mask1, mask2): 231 | rrse = RRSE_torch(pred, true, mask1) 232 | corr = CORR_torch(pred, true, 0) 233 | return rrse, corr 234 | 235 | if __name__ == '__main__': 236 | pred = torch.Tensor([1, 2, 3,4]) 237 | true = torch.Tensor([2, 1, 4,5]) 238 | print(All_Metrics(pred, true, None, None)) 239 | 240 | -------------------------------------------------------------------------------- /lib/predifineGraph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import pickle 4 | import pandas as pd 5 | import torch 6 | import torch.nn as nn 7 | 8 | def pre_graph_dict(args): 9 | A_dict_np = {} 10 | A_dict = {} 11 | lap_dict = {} 12 | node_dict = {} 13 | node_dict['PEMS08'], node_dict['PEMS07'], node_dict['PEMS04'], node_dict['PEMS03'] = 170, 883, 307, 358 14 | for data_graph in args.dataset_graph: 15 | if data_graph == 'PEMS08' or data_graph == 'PEMS04' or data_graph == 'PEMS07': 16 | A, Distance = get_adjacency_matrix(distance_df_filename='../data/' + data_graph + '/' + data_graph + '.csv', 17 | num_of_vertices=node_dict[data_graph]) 18 | elif data_graph == 'PEMS03': 19 | A, Distance = get_adjacency_matrix( 20 | distance_df_filename='../data/' + data_graph + '/' + data_graph + '.csv', 21 | num_of_vertices=node_dict[data_graph], id_filename='../data/' + data_graph + '/' + data_graph + '.txt') 22 | elif data_graph == 'PEMS07M': 23 | A = weight_matrix('../data/' + data_graph + '/' + data_graph + '.csv').astype(np.float32) 24 | A = A + np.eye(A.shape[0]) 25 | elif data_graph == 'NYC_BIKE': 26 | A = pd.read_csv('../data/' + data_graph + '/' + data_graph + '.csv', header=None).values.astype(np.float32) 27 | elif data_graph == 'chengdu_didi': 28 | A = np.load('../data/' + data_graph + '/' + 'matrix.npy').astype(np.float32) 29 | elif data_graph == 'CA_District5': 30 | A = np.load('../data/' + data_graph + '/' + data_graph + '.npy').astype(np.float32) 31 | else: 32 | sensor_ids, sensor_id_to_ind, A = load_pickle(pickle_file='./data/' + data_graph + '/' + 'adj_mx.pkl') 33 | lpls = cal_lape(A.copy()) 34 | lpls = torch.FloatTensor(lpls).to(args.device) 35 | if not args.use_lpls: 36 | nn.init.xavier_uniform_(lpls) 37 | lap_dict[data_graph] = lpls 38 | A = get_normalized_adj(A) 39 | A_dict_np[data_graph] = A 40 | A = torch.FloatTensor(A).to(args.device) 41 | A_dict[data_graph] = A 42 | args.A_dict_np = A_dict_np 43 | args.A_dict = A_dict 44 | args.lpls_dict = lap_dict 45 | 46 | def get_adjacency_matrix(distance_df_filename, num_of_vertices, id_filename=None): 47 | ''' 48 | Parameters 49 | ---------- 50 | distance_df_filename: str, path of the csv file contains edges information 51 | 52 | num_of_vertices: int, the number of vertices 53 | 54 | Returns 55 | ---------- 56 | A: np.ndarray, adjacency matrix 57 | 58 | ''' 59 | if 'npy' in distance_df_filename: 60 | 61 | adj_mx = np.load(distance_df_filename) 62 | 63 | return adj_mx, None 64 | 65 | else: 66 | 67 | import csv 68 | 69 | A = np.zeros((int(num_of_vertices), int(num_of_vertices)), 70 | dtype=np.float32) 71 | 72 | distaneA = np.zeros((int(num_of_vertices), int(num_of_vertices)), 73 | dtype=np.float32) 74 | 75 | if id_filename: 76 | 77 | with open(id_filename, 'r') as f: 78 | id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 把节点id(idx)映射成从0开始的索引 79 | 80 | with open(distance_df_filename, 'r') as f: 81 | f.readline() 82 | reader = csv.reader(f) 83 | for row in reader: 84 | if len(row) != 3: 85 | continue 86 | i, j, distance = int(row[0]), int(row[1]), float(row[2]) 87 | A[id_dict[i], id_dict[j]] = 1 88 | distaneA[id_dict[i], id_dict[j]] = distance 89 | return A, distaneA 90 | 91 | else: 92 | 93 | with open(distance_df_filename, 'r') as f: 94 | f.readline() 95 | reader = csv.reader(f) 96 | for row in reader: 97 | if len(row) != 3: 98 | continue 99 | i, j, distance = int(row[0]), int(row[1]), float(row[2]) 100 | A[i, j] = 1 101 | distaneA[i, j] = distance 102 | return A, distaneA 103 | 104 | def load_pickle(pickle_file): 105 | try: 106 | with open(pickle_file, 'rb') as f: 107 | pickle_data = pickle.load(f) 108 | except UnicodeDecodeError as e: 109 | with open(pickle_file, 'rb') as f: 110 | pickle_data = pickle.load(f, encoding='latin1') 111 | except Exception as e: 112 | print('Unable to load data ', pickle_file, ':', e) 113 | raise 114 | return pickle_data 115 | 116 | 117 | 118 | def calculate_scaled_laplacian(adj): 119 | """ 120 | L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 121 | L' = 2L/lambda - I 122 | 123 | Args: 124 | adj: adj_matrix 125 | 126 | Returns: 127 | np.ndarray: L' 128 | """ 129 | n = adj.shape[0] 130 | d = np.sum(adj, axis=1) # D 131 | lap = np.diag(d) - adj # L=D-A 132 | for i in range(n): 133 | for j in range(n): 134 | if d[i] > 0 and d[j] > 0: 135 | lap[i, j] /= np.sqrt(d[i] * d[j]) 136 | lap[np.isinf(lap)] = 0 137 | lap[np.isnan(lap)] = 0 138 | lam = np.linalg.eigvals(lap).max().real 139 | return 2 * lap / lam - np.eye(n) 140 | 141 | 142 | 143 | def weight_matrix(file_path, sigma2=0.1, epsilon=0.5, scaling=True): 144 | ''' 145 | From STGCN-IJCAI2018 146 | Load weight matrix function. 147 | :param file_path: str, the path of saved weight matrix file. 148 | :param sigma2: float, scalar of matrix W. 149 | :param epsilon: float, thresholds to control the sparsity of matrix W. 150 | :param scaling: bool, whether applies numerical scaling on W. 151 | :return: np.ndarray, [n_route, n_route]. 152 | ''' 153 | try: 154 | W = pd.read_csv(file_path, header=None).values 155 | except FileNotFoundError: 156 | print(f'ERROR: input file was not found in {file_path}.') 157 | 158 | # check whether W is a 0/1 matrix. 159 | if set(np.unique(W)) == {0, 1}: 160 | print('The input graph is a 0/1 matrix; set "scaling" to False.') 161 | scaling = False 162 | 163 | if scaling: 164 | n = W.shape[0] 165 | W = W / 10000. 166 | W2, WMASK = W * W, np.ones([n, n]) - np.identity(n) 167 | # refer to Eq.10 168 | A = np.exp(-W2 / sigma2) * (np.exp(-W2 / sigma2) >= epsilon) * WMASK 169 | return A 170 | else: 171 | return W 172 | 173 | 174 | def first_approx(W, n): 175 | ''' 176 | 1st-order approximation function. 177 | :param W: np.ndarray, [n_route, n_route], weighted adjacency matrix of G. 178 | :param n: int, number of routes / size of graph. 179 | :return: np.ndarray, [n_route, n_route]. 180 | ''' 181 | A = W + np.identity(n) 182 | d = np.sum(A, axis=1) 183 | sinvD = np.sqrt(np.mat(np.diag(d)).I) 184 | # refer to Eq.5 185 | return np.mat(np.identity(n) + sinvD * A * sinvD) 186 | 187 | def get_normalized_adj(A): 188 | """ 189 | Returns the degree normalized adjacency matrix. 190 | """ 191 | A = A + np.diag(np.ones(A.shape[0], dtype=np.float32)) 192 | D = np.array(np.sum(A, axis=1)).reshape((-1,)) 193 | D[D <= 10e-5] = 10e-5 # Prevent infs 194 | diag = np.reciprocal(np.sqrt(D)) 195 | A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A), 196 | diag.reshape((1, -1))) 197 | return A_wave 198 | 199 | def asym_adj(adj): 200 | adj = sp.coo_matrix(adj) 201 | rowsum = np.array(adj.sum(1)).flatten() 202 | d_inv = np.power(rowsum, -1).flatten() 203 | d_inv[np.isinf(d_inv)] = 0. 204 | d_mat= sp.diags(d_inv) 205 | return d_mat.dot(adj).astype(np.float32).todense() 206 | 207 | 208 | def idEncode(x, y, col): 209 | return x * col + y 210 | 211 | def constructGraph(row, col): 212 | mx = [-1, 0, 1, 0, -1, -1, 1, 1, 0] 213 | my = [0, -1, 0, 1, -1, 1, -1, 1, 0] 214 | 215 | areaNum = row * col 216 | 217 | def illegal(x, y): 218 | return x < 0 or y < 0 or x >= row or y >= col 219 | 220 | W = np.zeros((areaNum, areaNum)) 221 | for i in range(row): 222 | for j in range(col): 223 | n1 = idEncode(i, j, col) 224 | for k in range(len(mx)): 225 | temx = i + mx[k] 226 | temy = j + my[k] 227 | if illegal(temx, temy): 228 | continue 229 | n2 = idEncode(temx, temy, col) 230 | W[n1, n2] = 1 231 | return W 232 | 233 | 234 | def calculate_normalized_laplacian(adj): 235 | adj = sp.coo_matrix(adj) 236 | d = np.array(adj.sum(1)) 237 | isolated_point_num = np.sum(np.where(d, 0, 1)) 238 | print(f"Number of isolated points: {isolated_point_num}") 239 | d_inv_sqrt = np.power(d, -0.5).flatten() 240 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 241 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 242 | normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 243 | return normalized_laplacian, isolated_point_num 244 | 245 | def cal_lape(adj_mx): 246 | lape_dim = 32 247 | L, isolated_point_num = calculate_normalized_laplacian(adj_mx) 248 | EigVal, EigVec = np.linalg.eig(L.toarray()) 249 | idx = EigVal.argsort() 250 | EigVal, EigVec = EigVal[idx], np.real(EigVec[:, idx]) 251 | 252 | laplacian_pe = EigVec[:, isolated_point_num + 1: lape_dim + isolated_point_num + 1] 253 | return laplacian_pe -------------------------------------------------------------------------------- /model/AGCRN/AGCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class AVWGCN(nn.Module): 6 | def __init__(self, dim_in, dim_out, cheb_k, embed_dim): 7 | super(AVWGCN, self).__init__() 8 | self.cheb_k = cheb_k 9 | # self.weights_pool = nn.Parameter(torch.randn(embed_dim, cheb_k, dim_in, dim_out), requires_grad=True) 10 | # self.bias_pool = nn.Parameter(torch.randn(embed_dim, dim_out), requires_grad=True) 11 | self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) 12 | self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) 13 | def forward(self, x, node_embeddings): 14 | #x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] 15 | #output shape [B, N, C] 16 | node_num = node_embeddings.shape[0] 17 | supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1) 18 | support_set = [torch.eye(node_num).to(supports.device), supports] 19 | #default cheb_k = 3 20 | for k in range(2, self.cheb_k): 21 | support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2]) 22 | supports = torch.stack(support_set, dim=0) 23 | weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool) #N, cheb_k, dim_in, dim_out 24 | bias = torch.matmul(node_embeddings, self.bias_pool) #N, dim_out 25 | x_g = torch.einsum("knm,bmc->bknc", supports, x) #B, cheb_k, N, dim_in 26 | x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in 27 | x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias #b, N, dim_out 28 | return x_gconv -------------------------------------------------------------------------------- /model/AGCRN/AGCRN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.AGCRN.AGCRNCell import AGCRNCell 4 | 5 | class AVWDCRNN(nn.Module): 6 | def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): 7 | super(AVWDCRNN, self).__init__() 8 | assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.' 9 | self.node_num = node_num 10 | self.input_dim = dim_in 11 | self.out_dim = dim_out 12 | self.num_layers = num_layers 13 | self.dcrnn_cells = nn.ModuleList() 14 | self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)) 15 | for _ in range(1, num_layers): 16 | self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)) 17 | 18 | def forward(self, x, init_state, node_embeddings): 19 | #shape of x: (B, T, N, D) 20 | #shape of init_state: (num_layers, B, N, hidden_dim) 21 | # assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim 22 | seq_length = x.shape[1] 23 | current_inputs = x 24 | output_hidden = [] 25 | for i in range(self.num_layers): 26 | state = init_state[i] 27 | inner_states = [] 28 | for t in range(seq_length): 29 | state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings) 30 | inner_states.append(state) 31 | output_hidden.append(state) 32 | current_inputs = torch.stack(inner_states, dim=1) 33 | #current_inputs: the outputs of last layer: (B, T, N, hidden_dim) 34 | #output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) 35 | #last_state: (B, N, hidden_dim) 36 | return current_inputs, output_hidden 37 | 38 | def init_hidden(self, batch_size, node_dataset): 39 | init_states = [] 40 | for i in range(self.num_layers): 41 | init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size, node_dataset)) 42 | return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim) 43 | 44 | class AGCRN(nn.Module): 45 | def __init__(self, args, dim_in, dim_out, A_dict, dataset_use, data_test, mode): 46 | super(AGCRN, self).__init__() 47 | self.A_dict = A_dict 48 | self.mode = mode 49 | self.num_node = args.num_nodes 50 | self.input_dim = dim_in 51 | self.hidden_dim = args.rnn_units 52 | self.output_dim = dim_out 53 | self.horizon = args.horizon 54 | self.num_layers = args.num_layers 55 | 56 | self.default_graph = args.default_graph 57 | 58 | self.dataset2index = {} 59 | if mode == 'pretrain': 60 | self.neb_pretrain = [] 61 | for i, data_graph in enumerate(dataset_use): 62 | self.dataset2index[data_graph] = i 63 | n_dataset = A_dict[data_graph].shape[0] 64 | self.neb_pretrain.append(nn.Parameter(torch.randn(n_dataset, args.embed_dim).to(args.device), requires_grad=True)) 65 | else: 66 | self.neb_eval = [] 67 | for i, data_graph in enumerate([data_test]): 68 | self.dataset2index[data_graph] = i 69 | n_dataset = A_dict[data_graph].shape[0] 70 | self.neb_eval.append(nn.Parameter(torch.randn(n_dataset, args.embed_dim).to(args.device), requires_grad=True)) 71 | 72 | 73 | # self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True) 74 | 75 | self.Lin_input = nn.Linear(self.hidden_dim, 1) 76 | 77 | self.encoder = AVWDCRNN(args.num_nodes, self.input_dim, args.rnn_units, args.cheb_k, 78 | args.embed_dim, args.num_layers) 79 | 80 | #predictor 81 | self.end_conv = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) 82 | 83 | def forward(self, source, select_dataset): 84 | #source: B, T_1, N, D 85 | #target: B, T_2, N, D 86 | #supports = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec1.transpose(0,1))), dim=1) 87 | 88 | 89 | init_state = self.encoder.init_hidden(source.shape[0], self.A_dict[select_dataset].shape[0]) 90 | # source = self.Lin_input(source) 91 | if self.mode == 'pretrain': 92 | output, _ = self.encoder(source, init_state, self.neb_pretrain[self.dataset2index[select_dataset]]) #B, T, N, hidden 93 | else: 94 | output, _ = self.encoder(source, init_state, self.neb_eval[self.dataset2index[select_dataset]]) # B, T, N, hidden 95 | output = output[:, -1:, :, :] #B, 1, N, hidden 96 | 97 | #CNN based predictor 98 | output = self.end_conv((output)) #B, T*C, N, 1 99 | output = output.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.A_dict[select_dataset].shape[0]) 100 | output = output.permute(0, 1, 3, 2) #B, T, N, C 101 | 102 | return output -------------------------------------------------------------------------------- /model/AGCRN/AGCRNCell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from AGCN import AVWGCN 4 | from model.AGCRN.AGCN import AVWGCN 5 | 6 | class AGCRNCell(nn.Module): 7 | def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): 8 | super(AGCRNCell, self).__init__() 9 | self.dim_in = dim_in 10 | self.node_num = node_num 11 | self.hidden_dim = dim_out 12 | self.gate = AVWGCN(dim_in+self.hidden_dim, 2*dim_out, cheb_k, embed_dim) 13 | self.update = AVWGCN(dim_in+self.hidden_dim, dim_out, cheb_k, embed_dim) 14 | 15 | def forward(self, x, state, node_embeddings): 16 | #x: B, num_nodes, input_dim 17 | #state: B, num_nodes, hidden_dim 18 | state = state.to(x.device) 19 | input_and_state = torch.cat((x, state), dim=-1) 20 | z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) 21 | z, r = torch.split(z_r, self.hidden_dim, dim=-1) 22 | candidate = torch.cat((x, z*state), dim=-1) 23 | hc = torch.tanh(self.update(candidate, node_embeddings)) 24 | h = r*state + (1-r)*hc 25 | return h 26 | 27 | def init_hidden_state(self, batch_size, node_dataset): 28 | return torch.zeros(batch_size, node_dataset, self.hidden_dim) -------------------------------------------------------------------------------- /model/AGCRN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | 4 | def parse_args(DATASET): 5 | config_file = '../conf/AGCRN/{}.conf'.format(DATASET) 6 | print('Read configuration file: %s' % (config_file)) 7 | config = configparser.ConfigParser() 8 | config.read(config_file) 9 | print(config) 10 | #parser 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--dataset', default=DATASET, type=str) 13 | parser.add_argument('--device', default='cuda:0', type=str, help='indices of GPUs') 14 | parser.add_argument('--debug', default=False, type=eval) 15 | parser.add_argument('--cuda', default=True, type=bool) 16 | #data 17 | parser.add_argument('--lag', default=config['data']['lag'], type=int) 18 | parser.add_argument('--horizon', default=config['data']['horizon'], type=int) 19 | parser.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int) 20 | parser.add_argument('--tod', default=config['data']['tod'], type=eval) 21 | parser.add_argument('--normalizer', default=config['data']['normalizer'], type=str) 22 | parser.add_argument('--column_wise', default=config['data']['column_wise'], type=eval) 23 | parser.add_argument('--default_graph', default=config['data']['default_graph'], type=eval) 24 | #model 25 | parser.add_argument('--input_dim', default=config['model']['input_dim'], type=int) 26 | parser.add_argument('--output_dim', default=config['model']['output_dim'], type=int) 27 | parser.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int) 28 | parser.add_argument('--rnn_units', default=config['model']['rnn_units'], type=int) 29 | parser.add_argument('--num_layers', default=config['model']['num_layers'], type=int) 30 | parser.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int) 31 | args, _ = parser.parse_known_args() 32 | return args -------------------------------------------------------------------------------- /model/ASTGCN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import configparser 4 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 5 | import torch 6 | import pandas as pd 7 | 8 | def parse_args(DATASET, args_base): 9 | # get configuration 10 | config_file = '../conf/ASTGCN/{}.conf'.format(DATASET) 11 | config = configparser.ConfigParser() 12 | config.read(config_file) 13 | 14 | parser = argparse.ArgumentParser() 15 | # data 16 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 17 | parser.add_argument('--len_input', type=int, default=config['data']['len_input']) 18 | parser.add_argument('--num_for_predict', type=int, default=config['data']['num_for_predict']) 19 | # model 20 | parser.add_argument('--nb_block', type=int, default=config['model']['nb_block']) 21 | parser.add_argument('--K', type=int, default=config['model']['K']) 22 | parser.add_argument('--nb_chev_filter', type=int, default=config['model']['nb_chev_filter']) 23 | parser.add_argument('--nb_time_filter', type=int, default=config['model']['nb_time_filter']) 24 | parser.add_argument('--time_strides', type=int, default=config['model']['time_strides']) 25 | args, _ = parser.parse_known_args() 26 | return args -------------------------------------------------------------------------------- /model/FlashST.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FlashST(nn.Module): 6 | def __init__(self, args): 7 | super(FlashST, self).__init__() 8 | self.num_node = args.num_nodes 9 | self.input_base_dim = args.input_base_dim 10 | self.input_extra_dim = args.input_extra_dim 11 | self.output_dim = args.output_dim 12 | self.his = args.his 13 | self.pred = args.pred 14 | self.embed_dim = args.embed_dim 15 | self.mode = args.mode 16 | self.model = args.model 17 | self.load_pretrain_path = args.load_pretrain_path 18 | self.log_dir = args.log_dir 19 | self.args = args 20 | 21 | if self.mode == 'ori': 22 | dim_in = self.input_base_dim 23 | else: 24 | dim_in = self.embed_dim*4 25 | args.dim_in = dim_in 26 | dim_out = self.output_dim 27 | 28 | if self.model == 'AGCRN': 29 | from model.AGCRN.AGCRN import AGCRN 30 | from model.AGCRN.args import parse_args 31 | args_predictor = parse_args(args.dataset_test) 32 | self.predictor = AGCRN(args_predictor, dim_in, dim_out, args.A_dict, args.dataset_use, args.dataset_test, self.mode) 33 | elif self.model == 'MTGNN': 34 | from model.MTGNN.MTGNN import MTGNN 35 | from model.MTGNN.args import parse_args 36 | args_predictor = parse_args(args.dataset_test) 37 | self.predictor = MTGNN(args_predictor, dim_in, dim_out, args.A_dict, args.dataset_use, args.dataset_test, self.mode) 38 | elif self.model == 'STGCN': 39 | from model.STGCN.stgcn import STGCN 40 | from model.STGCN.args import parse_args 41 | args_predictor = parse_args(args.dataset_test, args) 42 | self.predictor = STGCN(args_predictor, args_predictor.G_dict, args.dataset_use, [args.dataset_test], args.num_nodes, dim_in, dim_out, args.device, self.mode) 43 | elif self.model == 'STSGCN': 44 | from model.STSGCN.STSGCN import STSGCN 45 | from model.STSGCN.args import parse_args 46 | args_predictor = parse_args(args.dataset_test) 47 | self.predictor = STSGCN(args_predictor, args.num_nodes, args.his, dim_in, dim_out, args.A_dict, args.dataset_use, args.dataset_test, args.device, self.mode) 48 | elif self.model == 'ASTGCN': 49 | from model.ASTGCN.ASTGCN import ASTGCN 50 | from model.ASTGCN.args import parse_args 51 | args_predictor = parse_args(args.dataset_test, args) 52 | self.predictor = ASTGCN(args_predictor, args.A_dict[args.dataset_test], args_predictor.num_nodes, args_predictor.len_input, args_predictor.num_for_predict, dim_in, dim_out, args.device) 53 | elif self.model == 'GWN': 54 | from model.GWN.GWN import GWNET 55 | # from GWN.GWNori import gwnet 56 | from model.GWN.args import parse_args 57 | args_predictor = parse_args(args.dataset_test) 58 | self.predictor = GWNET(args_predictor, dim_in, dim_out, args.A_dict, args.dataset_use, args.dataset_test, self.mode) 59 | elif self.model == 'DMSTGCN': 60 | from model.DMSTGCN.DMSTGCN import DMSTGCN 61 | self.predictor = DMSTGCN(args.device, dim_in, args.A_dict, args.dataset_use, args.dataset_test, self.mode) 62 | elif self.model == 'TGCN': 63 | from model.TGCN.TGCN import TGCN 64 | from model.TGCN.args import parse_args 65 | args_predictor = parse_args(args.dataset_test) 66 | self.predictor = TGCN(args_predictor, args.A_dict, args.dataset_test, args.device, dim_in) 67 | elif self.model == 'STFGNN': 68 | from model.STFGNN.STFGNN import STFGNN 69 | from model.STFGNN.args import parse_args 70 | args_predictor = parse_args(args.dataset_test, args) 71 | self.predictor = STFGNN(args_predictor, dim_in) 72 | elif self.model == 'STGODE': 73 | from model.STGODE.STGODE import ODEGCN 74 | from model.STGODE.args import parse_args 75 | args_predictor = parse_args(args.dataset_test, args) 76 | self.predictor = ODEGCN(args.num_nodes, dim_in, args.his, args.pred, args_predictor.A_sp_wave_dict, 77 | args_predictor.A_se_wave_dict, dim_out, args.A_dict, args.dataset_use, args.dataset_test, self.mode, args.device) 78 | elif self.model == 'STWA': 79 | from model.ST_WA.ST_WA import STWA 80 | from model.ST_WA.args import parse_args 81 | args_predictor = parse_args(args.dataset_test) 82 | self.predictor = STWA(args_predictor.device, args_predictor.num_nodes, dim_in, args_predictor.out_dim, 83 | args_predictor.channels, args_predictor.dynamic, args_predictor.lag, 84 | args_predictor.horizon, args_predictor.supports, args_predictor.memory_size, args.A_dict, args.dataset_use, args.dataset_test, self.mode) 85 | elif self.model == 'MSDR': 86 | from model.MSDR.gmsdr_model import GMSDRModel 87 | from model.MSDR.args import parse_args 88 | args_predictor = parse_args(args.dataset_test) 89 | args_predictor.input_dim = dim_in 90 | args_predictor.A_dict = args.A_dict 91 | args_predictor.dataset_use = args.dataset_use 92 | args_predictor.dataset_test = args.dataset_test 93 | args_predictor.mode = args.mode 94 | self.predictor = GMSDRModel(args_predictor) 95 | 96 | elif self.model == 'PDFormer': 97 | from model.PDFormer.PDFformer import PDFormer 98 | from model.PDFormer.args import parse_args 99 | args_predictor = parse_args(args.dataset_test, args) 100 | self.predictor = PDFormer(args_predictor, args) 101 | 102 | if self.mode == 'eval': 103 | for param in self.predictor.parameters(): 104 | param.requires_grad = False 105 | # STGCN # 106 | if self.model == 'STGCN': 107 | for param in self.predictor.st_conv1.ln_eval.parameters(): 108 | param.requires_grad = True 109 | for param in self.predictor.st_conv2.ln_eval.parameters(): 110 | param.requires_grad = True 111 | for param in self.predictor.output.ln_eval.parameters(): 112 | param.requires_grad = True 113 | for param in self.predictor.output.fc_pretrain.parameters(): 114 | param.requires_grad = True 115 | # GWN # 116 | elif self.model == 'GWN': 117 | for param in self.predictor.nodevec1_eval: 118 | param.requires_grad = True 119 | for param in self.predictor.nodevec2_eval: 120 | param.requires_grad = True 121 | for param in self.predictor.end_conv_2.parameters(): 122 | param.requires_grad = True 123 | # MTGNN # 124 | elif self.model == 'MTGNN': 125 | for param in self.predictor.gc_eval.parameters(): 126 | param.requires_grad = True 127 | for param in self.predictor.norm_eval.parameters(): 128 | param.requires_grad = True 129 | # PDFormer # 130 | elif self.model == 'PDFormer': 131 | for param in self.predictor.end_conv1.parameters(): 132 | param.requires_grad = True 133 | for param in self.predictor.end_conv2.parameters(): 134 | param.requires_grad = True 135 | 136 | if (args.mode == 'pretrain' or args.mode == 'ori') and args.xavier: 137 | for p in self.predictor.parameters(): 138 | if p.dim() > 1 and p.requires_grad: 139 | nn.init.xavier_uniform_(p) 140 | else: 141 | nn.init.uniform_(p) 142 | 143 | if self.mode != 'ori': 144 | from PromptNet import PromptNet 145 | self.pretrain_model = PromptNet(args) 146 | 147 | 148 | def forward(self, source, label, select_dataset, batch_seen=None, nadj=None, lpls=None, useGNN=False): 149 | if self.mode == 'ori': 150 | return self.forward_ori(source, label, select_dataset, batch_seen) 151 | else: 152 | return self.forward_pretrain(source, label, select_dataset, batch_seen, nadj, lpls, useGNN) 153 | 154 | def forward_pretrain(self, source, label, select_dataset, batch_seen=None, nadj=None, lpls=None, useGNN=False): 155 | x_prompt_return = self.pretrain_model(source[..., :self.input_base_dim], source, None, nadj, lpls, useGNN) 156 | if self.model == 'DMSTGCN': 157 | x_predic = self.predictor(x_prompt_return, source[:, 0, 0, 1], select_dataset) # MTGNN 158 | else: 159 | x_predic = self.predictor(x_prompt_return, select_dataset) # STGCN 160 | return x_predic, x_prompt_return 161 | 162 | def forward_ori(self, source, label=None, select_dataset=None, batch_seen=None): 163 | if self.model == 'DMSTGCN': 164 | x_predic = self.predictor(source[..., :self.input_base_dim], source[:, 0, 0, 1], select_dataset) # MTGNN 165 | else: 166 | x_predic = self.predictor(source[..., :self.input_base_dim], select_dataset) 167 | return x_predic, None 168 | -------------------------------------------------------------------------------- /model/GWN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | import configparser 5 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 6 | import torch 7 | 8 | def parse_args(DATASET): 9 | # get configuration 10 | config_file = '../conf/GWN/{}.conf'.format(DATASET) 11 | config = configparser.ConfigParser() 12 | config.read(config_file) 13 | 14 | parser = argparse.ArgumentParser() 15 | # general 16 | parser.add_argument('--device', type=str, default=config['general']['device']) 17 | 18 | # data 19 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 20 | parser.add_argument('--input_window', type=int, default=config['data']['input_window']) 21 | parser.add_argument('--output_window', type=int, default=config['data']['output_window']) 22 | parser.add_argument('--output_dim', type=int, default=config['data']['output_dim']) 23 | # model 24 | parser.add_argument('--dropout', type=float, default=config['model']['dropout']) 25 | parser.add_argument('--blocks', type=int, default=config['model']['blocks']) 26 | parser.add_argument('--layers', type=int, default=config['model']['layers']) 27 | parser.add_argument('--gcn_bool', type=eval, default=config['model']['gcn_bool']) 28 | parser.add_argument('--addaptadj', type=eval, default=config['model']['addaptadj']) 29 | parser.add_argument('--adjtype', type=str, default=config['model']['adjtype']) 30 | parser.add_argument('--randomadj', type=eval, default=config['model']['randomadj']) 31 | parser.add_argument('--aptonly', type=eval, default=config['model']['aptonly']) 32 | parser.add_argument('--kernel_size', type=int, default=config['model']['kernel_size']) 33 | parser.add_argument('--nhid', type=int, default=config['model']['nhid']) 34 | parser.add_argument('--residual_channels', type=int, default=config['model']['residual_channels']) 35 | parser.add_argument('--dilation_channels', type=int, default=config['model']['dilation_channels']) 36 | 37 | args, _ = parser.parse_known_args() 38 | args.adj_mx = None 39 | return args -------------------------------------------------------------------------------- /model/MSDR/args.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import linalg 3 | import scipy.sparse as sp 4 | import pandas as pd 5 | import torch 6 | import configparser 7 | import argparse 8 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 9 | 10 | def calculate_normalized_laplacian(adj): 11 | """ 12 | # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 13 | # D = diag(A 1) 14 | :param adj: 15 | :return: 16 | """ 17 | adj = sp.coo_matrix(adj) 18 | d = np.array(adj.sum(1)) 19 | d_inv_sqrt = np.power(d, -0.5).flatten() 20 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 21 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 22 | normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 23 | return normalized_laplacian 24 | 25 | 26 | def calculate_random_walk_matrix(adj_mx): 27 | adj_mx = sp.coo_matrix(adj_mx) 28 | d = np.array(adj_mx.sum(1)) 29 | d_inv = np.power(d, -1).flatten() 30 | d_inv[np.isinf(d_inv)] = 0. 31 | d_mat_inv = sp.diags(d_inv) 32 | random_walk_mx = d_mat_inv.dot(adj_mx).tocoo() 33 | return random_walk_mx 34 | 35 | 36 | def calculate_reverse_random_walk_matrix(adj_mx): 37 | return calculate_random_walk_matrix(np.transpose(adj_mx)) 38 | 39 | 40 | def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True): 41 | if undirected: 42 | print(adj_mx) 43 | adj_mx = np.maximum.reduce([adj_mx, adj_mx.T]) 44 | L = calculate_normalized_laplacian(adj_mx) 45 | if lambda_max is None: 46 | lambda_max, _ = linalg.eigsh(L, 1, which='LM') 47 | lambda_max = lambda_max[0] 48 | L = sp.csr_matrix(L) 49 | M, _ = L.shape 50 | I = sp.identity(M, format='csr', dtype=L.dtype) 51 | L = (2 / lambda_max * L) - I 52 | return L.astype(np.float32) 53 | 54 | def get_adjacency_matrix(distance_df_filename, num_of_vertices, type_='connectivity', id_filename=None): 55 | A = np.zeros((int(num_of_vertices), int(num_of_vertices)), dtype=np.float32) 56 | 57 | if id_filename: 58 | with open(id_filename, 'r') as f: 59 | id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表 60 | df = pd.read_csv(distance_df_filename) 61 | for row in df.values: 62 | if len(row) != 3: 63 | continue 64 | i, j = int(row[0]), int(row[1]) 65 | A[id_dict[i], id_dict[j]] = 1 66 | A[id_dict[j], id_dict[i]] = 1 67 | 68 | return A 69 | df = pd.read_csv(distance_df_filename) 70 | for row in df.values: 71 | if len(row) != 3: 72 | continue 73 | i, j, distance = int(row[0]), int(row[1]), float(row[2]) 74 | if type_ == 'connectivity': 75 | A[i, j] = 1 76 | A[j, i] = 1 77 | elif type == 'distance': 78 | A[i, j] = 1 / distance 79 | A[j, i] = 1 / distance 80 | else: 81 | raise ValueError("type_ error, must be " 82 | "connectivity or distance!") 83 | 84 | return A 85 | 86 | def parse_args(DATASET): 87 | # get configuration 88 | config_file = '../conf/MSDR/{}.conf'.format(DATASET) 89 | config = configparser.ConfigParser() 90 | config.read(config_file) 91 | 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('--filter_type', default=config['model']['filter_type'], type=str) 94 | parser.add_argument('--data', default=DATASET, help='data path', type=str, ) 95 | parser.add_argument('--cl_decay_steps', type=int, default=config['model']['cl_decay_steps']) 96 | parser.add_argument('--num_nodes', type=int, default=config['model']['num_nodes']) 97 | parser.add_argument('--horizon', type=int, default=config['model']['horizon']) 98 | parser.add_argument('--seq_len', type=int, default=config['model']['seq_len']) 99 | parser.add_argument('--max_diffusion_step', type=int, default=config['model']['max_diffusion_step']) 100 | parser.add_argument('--num_rnn_layers', type=int, default=config['model']['num_rnn_layers']) 101 | parser.add_argument('--output_dim', type=int, default=config['model']['output_dim']) 102 | parser.add_argument('--rnn_units', type=int, default=config['model']['rnn_units']) 103 | parser.add_argument('--pre_k', type=int, default=config['model']['pre_k']) 104 | parser.add_argument('--pre_v', type=int, default=config['model']['pre_v']) 105 | parser.add_argument('--use_curriculum_learning', type=eval, default=config['model']['use_curriculum_learning']) 106 | parser.add_argument('--construct_type', type=str, default=config['model']['construct_type']) 107 | parser.add_argument('--l2lambda', type=int, default=config['model']['l2lambda']) 108 | 109 | args, _ = parser.parse_known_args() 110 | args.adj_mx = None 111 | return args 112 | 113 | -------------------------------------------------------------------------------- /model/MSDR/gmsdr_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .gmsdr_cell import GMSDRCell 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def count_parameters(model): 11 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 12 | 13 | 14 | class Seq2SeqAttrs: 15 | def __init__(self, args): 16 | self.max_diffusion_step = args.max_diffusion_step 17 | self.cl_decay_steps = args.cl_decay_steps 18 | self.filter_type = args.filter_type 19 | # self.num_nodes = args.num_nodes 20 | self.num_rnn_layers = args.num_rnn_layers 21 | self.rnn_units = args.rnn_units 22 | # self.hidden_state_size = self.num_nodes * self.rnn_units 23 | self.pre_k = args.pre_k 24 | self.pre_v = args.pre_v 25 | self.input_dim = args.input_dim 26 | self.output_dim = args.output_dim 27 | self.A_dict = args.A_dict 28 | self.dataset_use = args.dataset_use 29 | self.dataset_test = args.dataset_test 30 | self.mode = args.mode 31 | 32 | 33 | class EncoderModel(nn.Module, Seq2SeqAttrs): 34 | def __init__(self, args): 35 | nn.Module.__init__(self) 36 | Seq2SeqAttrs.__init__(self, args) 37 | self.input_dim = args.input_dim 38 | self.seq_len = args.seq_len # for the encoder 39 | self.mlp = nn.Linear(self.input_dim, self.rnn_units) 40 | self.gmsdr_layers = nn.ModuleList( 41 | [GMSDRCell(self.rnn_units, self.input_dim, self.max_diffusion_step, self.pre_k, self.pre_v, 42 | self.A_dict, self.dataset_use, self.dataset_test, self.mode, 43 | filter_type=self.filter_type) for _ in range(self.num_rnn_layers)]) 44 | 45 | def forward(self, inputs, hx_k, select_dataset): 46 | """ 47 | Encoder forward pass. 48 | 49 | :param inputs: shape (batch_size, self.num_nodes * self.input_dim) 50 | :param hx_k: (num_layers, batch_size, pre_k, self.num_nodes, self.rnn_units) 51 | optional, zeros if not provided 52 | :return: output: # shape (batch_size, self.hidden_state_size) 53 | hx_k # shape (num_layers, batch_size, pre_k, self.num_nodes, self.rnn_units) 54 | (lower indices mean lower layers) 55 | """ 56 | hx_ks = [] 57 | batch = inputs.shape[0] 58 | x = inputs.reshape(batch, -1, self.input_dim) 59 | output = self.mlp(x).view(batch, -1) 60 | for layer_num, dcgru_layer in enumerate(self.gmsdr_layers): 61 | next_hidden_state, new_hx_k = dcgru_layer(output, hx_k[layer_num], select_dataset) 62 | hx_ks.append(new_hx_k) 63 | output = next_hidden_state 64 | return output, torch.stack(hx_ks) 65 | 66 | 67 | class DecoderModel(nn.Module, Seq2SeqAttrs): 68 | def __init__(self, args): 69 | nn.Module.__init__(self) 70 | Seq2SeqAttrs.__init__(self, args) 71 | self.output_dim = args.output_dim 72 | self.horizon = args.horizon # for the decoder 73 | self.projection_layer = nn.Linear(self.rnn_units, self.output_dim) 74 | self.gmsdr_layers = nn.ModuleList( 75 | [GMSDRCell(self.rnn_units, self.rnn_units, self.max_diffusion_step, self.pre_k, self.pre_v, 76 | self.A_dict, self.dataset_use, self.dataset_test, self.mode, 77 | filter_type=self.filter_type) for _ in range(self.num_rnn_layers)]) 78 | 79 | def forward(self, inputs, hx_k, select_dataset): 80 | """ 81 | Decoder forward pass. 82 | 83 | :param inputs: shape (batch_size, self.num_nodes * self.output_dim) 84 | :param hx_k: (num_layers, batch_size, pre_k, num_nodes, rnn_units) 85 | optional, zeros if not provided 86 | :return: output: # shape (batch_size, self.num_nodes * self.output_dim) 87 | hidden_state # shape (num_layers, batch_size, self.hidden_state_size) 88 | (lower indices mean lower layers) 89 | """ 90 | hx_ks = [] 91 | output = inputs 92 | for layer_num, dcgru_layer in enumerate(self.gmsdr_layers): 93 | next_hidden_state, new_hx_k = dcgru_layer(output, hx_k[layer_num], select_dataset) 94 | hx_ks.append(new_hx_k) 95 | output = next_hidden_state 96 | 97 | projected = self.projection_layer(output.view(-1, self.rnn_units)) 98 | num_nodes = self.A_dict[select_dataset].shape[0] 99 | output = projected.view(-1, num_nodes * self.output_dim) 100 | 101 | return output, torch.stack(hx_ks) 102 | 103 | 104 | class GMSDRModel(nn.Module, Seq2SeqAttrs): 105 | def __init__(self, args): 106 | super().__init__() 107 | Seq2SeqAttrs.__init__(self, args) 108 | self.encoder_model = EncoderModel(args) 109 | self.decoder_model = DecoderModel(args) 110 | self.cl_decay_steps = args.cl_decay_steps 111 | self.use_curriculum_learning = args.use_curriculum_learning 112 | # self._logger = logger 113 | self.out = nn.Linear(self.rnn_units, self.decoder_model.output_dim) 114 | 115 | def _compute_sampling_threshold(self, batches_seen): 116 | return self.cl_decay_steps / ( 117 | self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps)) 118 | 119 | def encoder(self, inputs, select_dataset): 120 | """ 121 | encoder forward pass on t time steps 122 | :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 123 | :return: hx_k: (num_layers, batch_size, pre_k, num_sensor, rnn_units) 124 | """ 125 | num_nodes = inputs.shape[2] 126 | hx_k = torch.zeros(self.num_rnn_layers, inputs.shape[1], self.pre_k, num_nodes, self.rnn_units, 127 | device=device) 128 | outputs = [] 129 | for t in range(self.encoder_model.seq_len): 130 | output, hx_k = self.encoder_model(inputs[t], hx_k, select_dataset) 131 | outputs.append(output) 132 | return torch.stack(outputs), hx_k 133 | 134 | def decoder(self, inputs, hx_k, select_dataset, labels=None, batches_seen=None): 135 | """ 136 | Decoder forward pass 137 | :param inputs: (seq_len, batch_size, num_sensor * rnn_units) 138 | :param hx_k: (num_layers, batch_size, pre_k, num_sensor, rnn_units) 139 | :param labels: (self.horizon, batch_size, self.num_nodes * self.output_dim) [optional, not exist for inference] 140 | :param batches_seen: global step [optional, not exist for inference] 141 | :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) 142 | """ 143 | decoder_hx_k = hx_k 144 | decoder_input = inputs 145 | 146 | outputs = [] 147 | for t in range(self.decoder_model.horizon): 148 | decoder_output, decoder_hx_k = self.decoder_model(decoder_input[t], 149 | decoder_hx_k, select_dataset) 150 | outputs.append(decoder_output) 151 | outputs = torch.stack(outputs) 152 | return outputs 153 | 154 | 155 | def forward(self, inputs, select_dataset, labels=None, batches_seen=None): 156 | """ 157 | seq2seq forward pass 158 | :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 159 | :param labels: shape (horizon, batch_size, num_sensor * output) 160 | :param batches_seen: batches seen till now 161 | :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) 162 | """ 163 | inputs = inputs.transpose(0, 1) 164 | encoder_outputs, hx_k = self.encoder(inputs, select_dataset) 165 | # self._logger.debug("Encoder complete, starting decoder") 166 | outputs = self.decoder(encoder_outputs, hx_k, select_dataset, labels=None, batches_seen=batches_seen) 167 | # self._logger.debug("Decoder complete") 168 | # if batches_seen == 0: 169 | # self._logger.info( 170 | # "Total trainable parameters {}".format(count_parameters(self)) 171 | # ) 172 | 173 | if self.decoder_model.output_dim == 1: 174 | outputs = outputs.transpose(1, 0).unsqueeze(-1) 175 | else: 176 | time_step, batch_size = outputs.shape[0], outputs.shape[1] 177 | outputs = outputs.transpose(1, 0).unsqueeze(-1).reshape(batch_size, time_step, -1, self.decoder_model.output_dim) 178 | return outputs 179 | -------------------------------------------------------------------------------- /model/MTGNN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | import configparser 5 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 6 | import torch 7 | 8 | def parse_args(DATASET): 9 | # get configuration 10 | config_file = '../conf/MTGNN/{}.conf'.format(DATASET) 11 | config = configparser.ConfigParser() 12 | config.read(config_file) 13 | 14 | parser = argparse.ArgumentParser() 15 | # general 16 | parser.add_argument('--device', type=str, default=config['general']['device']) 17 | 18 | # data 19 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 20 | parser.add_argument('--input_window', type=int, default=config['data']['input_window']) 21 | parser.add_argument('--output_window', type=int, default=config['data']['output_window']) 22 | parser.add_argument('--output_dim', type=int, default=config['data']['output_dim']) 23 | # model 24 | parser.add_argument('--gcn_true', type=eval, default=config['model']['gcn_true']) 25 | parser.add_argument('--buildA_true', type=eval, default=config['model']['buildA_true']) 26 | parser.add_argument('--gcn_depth', type=int, default=config['model']['gcn_depth']) 27 | parser.add_argument('--dropout', type=float, default=config['model']['dropout']) 28 | parser.add_argument('--subgraph_size', type=int, default=config['model']['subgraph_size']) 29 | parser.add_argument('--node_dim', type=int, default=config['model']['node_dim']) 30 | parser.add_argument('--dilation_exponential', type=int, default=config['model']['dilation_exponential']) 31 | parser.add_argument('--conv_channels', type=int, default=config['model']['conv_channels']) 32 | parser.add_argument('--residual_channels', type=int, default=config['model']['residual_channels']) 33 | parser.add_argument('--skip_channels', type=int, default=config['model']['skip_channels']) 34 | parser.add_argument('--end_channels', type=int, default=config['model']['end_channels']) 35 | parser.add_argument('--layers', type=int, default=config['model']['layers']) 36 | parser.add_argument('--propalpha', type=float, default=config['model']['propalpha']) 37 | parser.add_argument('--tanhalpha', type=int, default=config['model']['tanhalpha']) 38 | parser.add_argument('--layer_norm_affline', type=eval, default=config['model']['layer_norm_affline']) 39 | parser.add_argument('--use_curriculum_learning', type=eval, default=config['model']['use_curriculum_learning']) 40 | parser.add_argument('--task_level', type=int, default=config['model']['task_level']) 41 | 42 | args, _ = parser.parse_known_args() 43 | args.adj_mx = None 44 | return args -------------------------------------------------------------------------------- /model/PDFormer/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from fastdtw import fastdtw 4 | from tqdm import tqdm 5 | from tslearn.clustering import TimeSeriesKMeans, KShape 6 | from lib.data_process import load_st_dataset 7 | import argparse 8 | import configparser 9 | 10 | def split_data_by_ratio(data, val_ratio, test_ratio): 11 | data_len = data.shape[0] 12 | test_data = data[-int(data_len*test_ratio):] 13 | val_data = data[-int(data_len*(test_ratio+val_ratio)):-int(data_len*test_ratio)] 14 | train_data = data[:-int(data_len*(test_ratio+val_ratio))] 15 | return train_data, val_data, test_data 16 | 17 | def Add_Window_Horizon(data, window=3, horizon=1, single=False): 18 | ''' 19 | :param data: shape [B, ...] 20 | :param window: 21 | :param horizon: 22 | :return: X is [B, W, ...], Y is [B, H, ...] 23 | ''' 24 | length = len(data) 25 | end_index = length - horizon - window + 1 26 | X = [] #windows 27 | Y = [] #horizon 28 | index = 0 29 | if single: 30 | while index < end_index: 31 | X.append(data[index:index+window]) 32 | Y.append(data[index+window+horizon-1:index+window+horizon]) 33 | index = index + 1 34 | else: 35 | while index < end_index: 36 | X.append(data[index:index+window]) 37 | Y.append(data[index+window:index+window+horizon]) 38 | index = index + 1 39 | X = np.array(X) 40 | Y = np.array(Y) 41 | return X, Y 42 | 43 | def get_dtw(df, args, dataset, num_nodes): 44 | filename = dataset 45 | cache_path = f'../data/PDFormer/{filename}/{filename}_dtw.npy' 46 | if not os.path.exists(cache_path): 47 | data_mean = np.mean( 48 | [df[24 * args.points_per_hour * i: 24 * args.points_per_hour * (i + 1)] 49 | for i in range(df.shape[0] // (24 * args.points_per_hour))], axis=0) 50 | dtw_distance = np.zeros((num_nodes, num_nodes)) 51 | for i in tqdm(range(num_nodes)): 52 | for j in range(i, num_nodes): 53 | dtw_distance[i][j], _ = fastdtw(data_mean[:, i, :], data_mean[:, j, :], radius=6) 54 | for i in range(num_nodes): 55 | for j in range(i): 56 | dtw_distance[i][j] = dtw_distance[j][i] 57 | np.save(cache_path, dtw_distance) 58 | dtw_matrix = np.load(cache_path) 59 | print('Load DTW matrix from {}'.format(cache_path)) 60 | return dtw_matrix 61 | 62 | def get_pattern_key(x_train, args, dataset, args_base): 63 | filename = dataset 64 | cache_path = f'../data/PDFormer/{filename}/{filename}_pattern.npy' 65 | if not os.path.exists(cache_path): 66 | cand_key_time_steps = args.cand_key_days * args.points_per_day 67 | pattern_cand_keys = x_train[:cand_key_time_steps, :args.s_attn_size, :, :args_base.output_dim].swapaxes(1,2).reshape( 68 | -1, args.s_attn_size, args_base.output_dim) 69 | print("Clustering...") 70 | if args.cluster_method == "kshape": 71 | km = KShape(n_clusters=args.n_cluster, max_iter=args.cluster_max_iter).fit(pattern_cand_keys) 72 | else: 73 | km = TimeSeriesKMeans(n_clusters=args.n_cluster, metric="softdtw", max_iter=args.cluster_max_iter).fit( 74 | pattern_cand_keys) 75 | pattern_keys = km.cluster_centers_ 76 | np.save(cache_path, pattern_keys) 77 | print("Saved at file " + cache_path) 78 | pattern_key_matrix = np.load(cache_path) 79 | return pattern_key_matrix 80 | 81 | def load_rel(adj_mx, args, dataset): 82 | filename = dataset 83 | cache_path = f'../data/PDFormer/{filename}/{filename}_sh_mx.npy' 84 | sh_mx = adj_mx.copy() 85 | if args.type_short_path == 'hop': 86 | if not os.path.exists(cache_path): 87 | print('Max adj_mx value = {}'.format(adj_mx.max())) 88 | num_nodes = adj_mx.shape[0] 89 | sh_mx[sh_mx > 0] = 1 90 | sh_mx[sh_mx == 0] = 511 91 | for i in range(num_nodes): 92 | sh_mx[i, i] = 0 93 | for k in range(num_nodes): 94 | for i in range(num_nodes): 95 | for j in range(num_nodes): 96 | sh_mx[i, j] = min(sh_mx[i, j], sh_mx[i, k] + sh_mx[k, j], 511) 97 | np.save(cache_path, sh_mx) 98 | sh_mx = np.load(cache_path) 99 | return sh_mx 100 | 101 | 102 | def parse_args(DATASET, args_base): 103 | # get configuration 104 | print(DATASET) 105 | config_file = '../conf/PDFormer/{}.conf'.format(DATASET) 106 | config = configparser.ConfigParser() 107 | config.read(config_file) 108 | 109 | parser = argparse.ArgumentParser() 110 | # model 111 | parser.add_argument('--embed_dim', type=int, default=config['model']['embed_dim']) 112 | parser.add_argument('--skip_dim', type=int, default=config['model']['skip_dim']) 113 | parser.add_argument('--lape_dim', type=int, default=config['model']['lape_dim']) 114 | 115 | parser.add_argument('--geo_num_heads', type=int, default=config['model']['geo_num_heads']) 116 | parser.add_argument('--sem_num_heads', type=int, default=config['model']['sem_num_heads']) 117 | parser.add_argument('--t_num_heads', type=int, default=config['model']['t_num_heads']) 118 | parser.add_argument('--mlp_ratio', type=int, default=config['model']['mlp_ratio']) 119 | parser.add_argument('--qkv_bias', type=eval, default=config['model']['qkv_bias']) 120 | parser.add_argument('--drop', type=float, default=config['model']['drop']) 121 | parser.add_argument('--attn_drop', type=float, default=config['model']['attn_drop']) 122 | parser.add_argument('--drop_path', type=float, default=config['model']['drop_path']) 123 | parser.add_argument('--s_attn_size', type=int, default=config['model']['s_attn_size']) 124 | parser.add_argument('--t_attn_size', type=int, default=config['model']['t_attn_size']) 125 | parser.add_argument('--enc_depth', type=int, default=config['model']['enc_depth']) 126 | parser.add_argument('--type_ln', type=str, default=config['model']['type_ln']) 127 | parser.add_argument('--type_short_path', type=str, default=config['model']['type_short_path']) 128 | parser.add_argument('--add_time_in_day', type=eval, default=config['model']['add_time_in_day']) 129 | parser.add_argument('--add_day_in_week', type=eval, default=config['model']['add_day_in_week']) 130 | 131 | parser.add_argument('--far_mask_delta', type=int, default=config['model']['far_mask_delta']) 132 | parser.add_argument('--dtw_delta', type=int, default=config['model']['dtw_delta']) 133 | parser.add_argument('--time_intervals', type=int, default=config['model']['time_intervals']) 134 | parser.add_argument('--cand_key_days', type=int, default=config['model']['cand_key_days']) 135 | parser.add_argument('--n_cluster', type=int, default=config['model']['n_cluster']) 136 | parser.add_argument('--cluster_max_iter', type=int, default=config['model']['cluster_max_iter']) 137 | parser.add_argument('--cluster_method', type=str, default=config['model']['cluster_method']) 138 | 139 | # self.s_attn_size = config.get("s_attn_size", 3) 140 | # self.n_cluster = config.get("n_cluster", 16) 141 | # self.cluster_max_iter = config.get("cluster_max_iter", 5) 142 | # self.cluster_method = config.get("cluster_method", "kshape") 143 | 144 | args_predictor, _ = parser.parse_known_args() 145 | 146 | args_predictor.points_per_hour = 3600 // args_predictor.time_intervals 147 | args_predictor.points_per_day = 24 * 3600 // args_predictor.time_intervals 148 | 149 | if args_base.mode == 'pretrain': 150 | data_list = args_base.dataset_use 151 | else: 152 | data_list = [args_base.dataset_test] 153 | dtw_matrix_dict = {} 154 | sh_mx_dict = {} 155 | pattern_key_matrix_dict = {} 156 | adj_mx_dict = {} 157 | 158 | for i, data_graph in enumerate(data_list): 159 | data = load_st_dataset(data_graph, args_base) 160 | data = data[..., 0:args_base.input_base_dim] 161 | data_train, data_val, data_test = split_data_by_ratio(data, args_base.val_ratio, args_base.test_ratio) 162 | x_tra, y_tra = Add_Window_Horizon(data_train, args_base.his, args_base.pred) 163 | num_nodes = args_base.A_dict_np[data_graph].shape[0] 164 | 165 | dtw_matrix = get_dtw(data, args_predictor, data_graph, num_nodes) 166 | dtw_matrix_dict[data_graph] = dtw_matrix 167 | sh_mx = load_rel(args_base.A_dict_np[data_graph], args_predictor, data_graph) 168 | sh_mx_dict[data_graph] = sh_mx 169 | pattern_key_matrix = get_pattern_key(x_tra, args_predictor, data_graph, args_base) 170 | pattern_key_matrix_dict[data_graph] = pattern_key_matrix 171 | args_predictor.dtw_matrix_dict = dtw_matrix_dict 172 | args_predictor.adj_mx_dict = args_base.A_dict_np 173 | args_predictor.sd_mx = None 174 | args_predictor.sh_mx_dict = sh_mx_dict 175 | args_predictor.pattern_key_matrix_dict = pattern_key_matrix_dict 176 | args_predictor.lap_mx_dict = args_base.lpls_dict 177 | # print(sss) 178 | return args_predictor -------------------------------------------------------------------------------- /model/PromptNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class GCN(nn.Module): 8 | def __init__(self, hidden_dim) -> None: 9 | super().__init__() 10 | self.fc1 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True) 11 | self.act = nn.LeakyReLU() 12 | 13 | def forward(self, input_data: torch.Tensor, nadj: torch.Tensor, useGNN=False) -> torch.Tensor: 14 | if useGNN: 15 | gcn_out = self.act(torch.einsum('nk,bdke->bdne', nadj, self.fc1(input_data))) 16 | # gcn_out = self.act(self.fc1(torch.einsum('nk,bdke->bdne', nadj, input_data))) 17 | else: 18 | gcn_out = self.act(self.fc1(input_data)) 19 | return gcn_out + input_data 20 | 21 | 22 | class MultiLayerPerceptron(nn.Module): 23 | 24 | def __init__(self, input_dim, hidden_dim) -> None: 25 | super().__init__() 26 | self.fc1 = nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True) 27 | self.fc2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True) 28 | self.act = nn.ReLU() 29 | self.drop = nn.Dropout(p=0.15) 30 | 31 | def forward(self, input_data: torch.Tensor) -> torch.Tensor: 32 | 33 | hidden = self.fc2(self.drop(self.act(self.fc1(input_data)))) 34 | hidden = hidden + input_data 35 | return hidden 36 | 37 | class PromptNet(nn.Module): 38 | def __init__(self, args): 39 | super(PromptNet, self).__init__() 40 | 41 | self.mode = args.mode 42 | self.node_dim = args.node_dim 43 | self.input_len = args.his 44 | self.embed_dim = args.embed_dim 45 | self.output_len = args.pred 46 | self.num_layer = args.num_layer 47 | self.temp_dim_tid = args.temp_dim_tid 48 | self.temp_dim_diw = args.temp_dim_diw 49 | 50 | self.if_time_in_day = args.if_time_in_day 51 | self.if_day_in_week = args.if_day_in_week 52 | self.if_spatial = args.if_spatial 53 | 54 | self.input_base_dim = args.input_base_dim 55 | self.input_extra_dim = args.input_extra_dim 56 | self.data_type = args.data_type 57 | 58 | # spatial embeddings 59 | if self.if_spatial: 60 | self.LaplacianPE1 = nn.Linear(self.node_dim, self.node_dim) 61 | self.LaplacianPE2 = nn.Linear(self.node_dim, self.node_dim) 62 | 63 | # temporal embeddings 64 | if self.if_time_in_day: 65 | self.time_in_day_emb = nn.Embedding(288+1, self.temp_dim_tid) 66 | if self.if_day_in_week: 67 | self.day_in_week_emb = nn.Embedding(7+1, self.temp_dim_diw) 68 | 69 | self.time_series_emb_layer = nn.Linear(self.input_len, self.embed_dim, bias=True) 70 | 71 | self.hidden_dim = self.embed_dim+self.node_dim * \ 72 | int(self.if_spatial)+self.temp_dim_tid*int(self.if_day_in_week) + \ 73 | self.temp_dim_diw*int(self.if_time_in_day) 74 | 75 | # Base 76 | self.encoder1 = nn.Sequential( 77 | *[MultiLayerPerceptron(self.hidden_dim, self.hidden_dim) for _ in range(self.num_layer)]) 78 | self.encoder2 = nn.Sequential( 79 | *[MultiLayerPerceptron(self.hidden_dim, self.hidden_dim) for _ in range(self.num_layer)]) 80 | self.gcn1 = GCN(self.hidden_dim) 81 | self.gcn2 = GCN(self.hidden_dim) 82 | 83 | self.act = nn.LeakyReLU() 84 | 85 | def forward(self, history_data, source2, batch_seen=None, nadj=None, lpls=None, useGNN=False): 86 | 87 | input_data = history_data 88 | batch_size, _, num_nodes, _ = input_data.shape 89 | 90 | ZERO = torch.IntTensor(1).to('cuda:0') 91 | if self.if_time_in_day: 92 | t_i_d_data = source2[:, 0, :, self.input_base_dim] 93 | time_in_day_emb = self.time_in_day_emb(t_i_d_data[:, :].type_as(ZERO)) 94 | else: 95 | time_in_day_emb = None 96 | if self.if_day_in_week: 97 | d_i_w_data = source2[:, 0, :, self.input_base_dim+1] 98 | day_in_week_emb = self.day_in_week_emb(d_i_w_data[:, :].type_as(ZERO)) 99 | else: 100 | day_in_week_emb = None 101 | 102 | time_series_emb = self.time_series_emb_layer(input_data[..., 0:self.input_base_dim].transpose(1, 3)) 103 | 104 | node_emb = [] 105 | if self.if_spatial: 106 | lap_pos_enc = self.LaplacianPE2(self.act(self.LaplacianPE1(lpls))) 107 | tensor_neb = lap_pos_enc.unsqueeze(0).expand(batch_size, -1, -1).unsqueeze(1).repeat(1, self.input_base_dim, 1, 1) 108 | node_emb.append(tensor_neb) 109 | 110 | # temporal embeddings 111 | tem_emb = [] 112 | if time_in_day_emb is not None: 113 | tem_emb.append(time_in_day_emb.unsqueeze(1)) 114 | if day_in_week_emb is not None: 115 | tem_emb.append(day_in_week_emb.unsqueeze(1)) 116 | 117 | # concate all embeddings 118 | hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=-1).transpose(1, 3) 119 | 120 | # encoding 121 | hidden_gcn = self.gcn1(hidden, nadj, useGNN) 122 | hidden = self.encoder1(hidden_gcn) 123 | hidden_gcn = self.gcn2(hidden, nadj, useGNN) 124 | hidden = self.encoder2(hidden_gcn) 125 | x_prompt = hidden.transpose(1, 3) + input_data[..., 0:self.input_base_dim] 126 | x_prompt = F.normalize(x_prompt, dim=-1) 127 | return x_prompt 128 | 129 | -------------------------------------------------------------------------------- /model/Run.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | file_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | print(file_dir) 6 | sys.path.append(file_dir) 7 | 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import argparse 12 | import configparser 13 | from datetime import datetime 14 | from Trainer import Trainer 15 | from FlashST import FlashST as Network_Pretrain 16 | from FlashST import FlashST as Network_Predict 17 | from lib.TrainInits import init_seed 18 | from lib.TrainInits import print_model_parameters 19 | from lib.metrics import MAE_torch, MSE_torch, huber_loss 20 | from lib.predifineGraph import * 21 | from lib.data_process import define_dataloder, get_val_tst_dataloader, data_type_init 22 | from conf.FlashST.Params_pretrain import parse_args 23 | import torch.nn.functional as F 24 | 25 | # *************************************************************************# 26 | # mode = 'eval' # pretrain eval ori test 27 | # dataset_test = ['CA_District5'] # NYC_BIKE, CA_District5, PEMS07M, chengdu_didi 28 | # dataset_use = ['PEMS08', 'PEMS04', 'PEMS07', 'PEMS03'] # PEMS08, PEMS04, PEMS07, PEMS03 29 | # model = 'STGCN' # TGCN STGCN ASTGCN GWN STSGCN AGCRN MTGNN STFGNN STGODE DMSTGCN MSDR STWA PDFormer 30 | 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | args = parse_args(device) 33 | 34 | print('Mode: ', args.mode, ' model: ', args.model, ' DATASET: ', args.dataset_test, 35 | ' load_pretrain_path: ', args.load_pretrain_path, ' save_pretrain_path: ', args.save_pretrain_path) 36 | 37 | 38 | def Mkdir(path): 39 | if os.path.isdir(path): 40 | pass 41 | else: 42 | os.makedirs(path) 43 | 44 | def infoNCEloss(): 45 | def loss(q, k): 46 | T = 0.3 47 | pos_sim = torch.sum(torch.mul(q, q), dim=-1) 48 | neg_sim = torch.matmul(q, q.transpose(-1, -2)) 49 | pos = torch.exp(torch.div(pos_sim, T)) 50 | neg = torch.sum(torch.exp(torch.div(neg_sim, T)), dim=-1) 51 | denominator = neg + pos 52 | return torch.mean(-torch.log(torch.div(pos, denominator))) 53 | return loss 54 | 55 | def scaler_mae_loss(mask_value): 56 | def loss(preds, labels, scaler, mask=None): 57 | if scaler: 58 | preds = scaler.inverse_transform(preds) 59 | labels = scaler.inverse_transform(labels) 60 | mae, mae_loss = MAE_torch(pred=preds, true=labels, mask_value=mask_value) 61 | # print(mae.shape, mae_loss.shape) 62 | return mae, mae_loss 63 | return loss 64 | 65 | def scaler_huber_loss(mask_value): 66 | def loss(preds, labels, scaler, mask=None): 67 | if scaler: 68 | preds = scaler.inverse_transform(preds) 69 | labels = scaler.inverse_transform(labels) 70 | mae, mae_loss = huber_loss(pred=preds, true=labels, mask_value=mask_value) 71 | # print(mae.shape, mae_loss.shape) 72 | return mae, mae_loss 73 | return loss 74 | 75 | if args.model == 'GWN' or args.model == 'MTGNN' or args.model == 'STFGNN' or args.model == 'STGODE' or args.model == 'DMSTGCN': 76 | seed_mode = False # for quick running 77 | else: 78 | seed_mode = True 79 | init_seed(args.seed, seed_mode) 80 | 81 | #config log path 82 | current_dir = os.path.dirname(os.path.realpath(__file__)) 83 | log_dir = os.path.join(current_dir, '../SAVE', args.mode, args.model) 84 | Mkdir(log_dir) 85 | args.log_dir = log_dir 86 | 87 | #predefine Graph 88 | dataset_graph = [] 89 | if args.mode == 'pretrain': 90 | dataset_graph = args.dataset_use.copy() 91 | else: 92 | dataset_graph.append(args.dataset_test) 93 | args.dataset_graph = dataset_graph 94 | pre_graph_dict(args) 95 | data_type_init(args.dataset_test, args) 96 | 97 | if args.model == 'STGODE' or args.model == 'AGCRN' or args.model == 'ASTGCN': 98 | xavier = True 99 | else: 100 | xavier = False 101 | 102 | args.xavier = xavier 103 | 104 | #load dataset 105 | if args.mode == 'pretrain': 106 | x_trn_dict, y_trn_dict, _, _, _, _, scaler_dict = define_dataloder(stage='Train', args=args) 107 | eval_train_loader, eval_val_loader, eval_test_loader, eval_scaler_dict = None, None, None, None 108 | else: 109 | x_trn_dict, y_trn_dict, scaler_dict = None, None, None 110 | eval_x_trn_dict, eval_y_trn_dict, eval_x_val_dict, eval_y_val_dict, eval_x_tst_dict, eval_y_tst_dict, eval_scaler_dict = define_dataloder(stage='eval', args=args) 111 | eval_train_loader = get_val_tst_dataloader(eval_x_trn_dict, eval_y_trn_dict, args, shuffle=True) 112 | eval_val_loader = get_val_tst_dataloader(eval_x_val_dict, eval_y_val_dict, args, shuffle=False) 113 | eval_test_loader = get_val_tst_dataloader(eval_x_tst_dict, eval_y_tst_dict, args, shuffle=False) 114 | 115 | 116 | #init model 117 | if args.mode == 'pretrain': 118 | model = Network_Pretrain(args) 119 | # if torch.cuda.device_count() > 1: 120 | # model = nn.DataParallel(model) 121 | model = model.to(args.device) 122 | else: 123 | model = Network_Predict(args) 124 | # if torch.cuda.device_count() > 1: 125 | # model = nn.DataParallel(model) 126 | model = model.to(args.device) 127 | if args.mode == 'eval': 128 | load_dir = os.path.join(current_dir, '../SAVE', 'pretrain', args.model) 129 | model.load_state_dict(torch.load(load_dir + '/' + args.load_pretrain_path), strict=False) 130 | print(load_dir + '/' + args.load_pretrain_path) 131 | print('load pretrain model!!!') 132 | 133 | print_model_parameters(model, only_num=False) 134 | 135 | #init loss function, optimizer 136 | if args.loss_func == 'mask_mae': 137 | if (args.model == 'STSGCN' or args.model == 'STFGNN' or args.model == 'STGODE'): 138 | loss = scaler_huber_loss(mask_value=args.mape_thresh) 139 | print('============================scaler_huber_loss') 140 | else: 141 | loss = scaler_mae_loss(mask_value=args.mape_thresh) 142 | print('============================scaler_mae_loss') 143 | # print(args.model, Mode) 144 | elif args.loss_func == 'mae': 145 | loss = torch.nn.L1Loss().to(args.device) 146 | elif args.loss_func == 'mse': 147 | loss = torch.nn.MSELoss().to(args.device) 148 | else: 149 | raise ValueError 150 | 151 | 152 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8, 153 | weight_decay=0, amsgrad=False) 154 | #learning rate decay 155 | lr_scheduler = None 156 | if args.lr_decay: 157 | print('Applying learning rate decay.') 158 | lr_decay_steps = [int(i) for i in list(args.lr_decay_step.split(','))] 159 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, 160 | milestones=lr_decay_steps, 161 | gamma=args.lr_decay_rate) 162 | 163 | #start training 164 | loss_mse = torch.nn.MSELoss().to(args.device) 165 | loss_ssl = infoNCEloss() 166 | trainer = Trainer(model, loss, loss_ssl, optimizer, x_trn_dict, y_trn_dict, args.A_dict, args.lpls_dict, eval_train_loader, 167 | eval_val_loader, eval_test_loader, scaler_dict, eval_scaler_dict, args, 168 | lr_scheduler=lr_scheduler) 169 | 170 | if args.mode == 'pretrain': 171 | trainer.train_pretrain() 172 | elif args.mode == 'eval': 173 | trainer.train_eval() 174 | elif args.mode == 'ori': 175 | trainer.train_eval() 176 | elif args.mode == 'test': 177 | # model.load_state_dict(torch.load(log_dir + '/' + args.load_pretrain_path), strict=True) 178 | # print("Load saved model") 179 | trainer.eval_test(model, trainer.args, args.A_dict, args.lpls_dict, eval_test_loader, eval_scaler_dict[args.dataset_test], trainer.logger) 180 | else: 181 | raise ValueError 182 | -------------------------------------------------------------------------------- /model/STFGNN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import configparser 4 | from lib.data_process import load_st_dataset 5 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 6 | import os 7 | # from fastdtw import fastdtw 8 | import torch 9 | import pandas as pd 10 | 11 | def gen_data(data, ntr, N, DATASET): 12 | ''' 13 | if flag: 14 | data=pd.read_csv(fname) 15 | else: 16 | data=pd.read_csv(fname,header=None) 17 | ''' 18 | #data=data.as_matrix() 19 | if DATASET == 'SZ_TAXI': 20 | data=np.reshape(data,[-1,96,N]) 21 | elif DATASET == 'NYC_TAXI' or DATASET == 'NYC_BIKE': 22 | data = np.reshape(data, [-1, 48, N]) 23 | else: 24 | data = np.reshape(data, [-1, 288, N]) 25 | return data[0:ntr] 26 | 27 | def normalize(a): 28 | mu=np.mean(a,axis=1,keepdims=True) 29 | std=np.std(a,axis=1,keepdims=True) 30 | return (a-mu)/std 31 | 32 | def compute_dtw(a,b,order=1,Ts=12,normal=True): 33 | if normal: 34 | a=normalize(a) 35 | b=normalize(b) 36 | T0=a.shape[1] 37 | d=np.reshape(a,[-1,1,T0])-np.reshape(b,[-1,T0,1]) 38 | d=np.linalg.norm(d,axis=0,ord=order) 39 | D=np.zeros([T0,T0]) 40 | for i in range(T0): 41 | for j in range(max(0,i-Ts),min(T0,i+Ts+1)): 42 | if (i==0) and (j==0): 43 | D[i,j]=d[i,j]**order 44 | continue 45 | if (i==0): 46 | D[i,j]=d[i,j]**order+D[i,j-1] 47 | continue 48 | if (j==0): 49 | D[i,j]=d[i,j]**order+D[i-1,j] 50 | continue 51 | if (j==i-Ts): 52 | D[i,j]=d[i,j]**order+min(D[i-1,j-1],D[i-1,j]) 53 | continue 54 | if (j==i+Ts): 55 | D[i,j]=d[i,j]**order+min(D[i-1,j-1],D[i,j-1]) 56 | continue 57 | D[i,j]=d[i,j]**order+min(D[i-1,j-1],D[i-1,j],D[i,j-1]) 58 | return D[-1,-1]**(1.0/order) 59 | 60 | def construct_dtw(data, DATASET): 61 | data = data[:, :, 0] 62 | if DATASET == 'SZ_TAXI': 63 | total_day = data.shape[0] / 96 64 | else: 65 | total_day = data.shape[0] / 288 66 | tr_day = int(total_day * 0.6) 67 | n_route = data.shape[1] 68 | xtr = gen_data(data, tr_day, n_route, DATASET) 69 | print(np.shape(xtr)) 70 | T0 = 288 71 | T = 12 72 | N = n_route 73 | d = np.zeros([N, N]) 74 | for i in range(N): 75 | for j in range(i + 1, N): 76 | d[i, j] = compute_dtw(xtr[:, :, i], xtr[:, :, j]) 77 | 78 | print("The calculation of time series is done!") 79 | dtw = d + d.T 80 | n = dtw.shape[0] 81 | w_adj = np.zeros([n, n]) 82 | adj_percent = 0.01 83 | top = int(n * adj_percent) 84 | for i in range(dtw.shape[0]): 85 | a = dtw[i, :].argsort()[0:top] 86 | for j in range(top): 87 | w_adj[i, a[j]] = 1 88 | 89 | for i in range(n): 90 | for j in range(n): 91 | if (w_adj[i][j] != w_adj[j][i] and w_adj[i][j] == 0): 92 | w_adj[i][j] = 1 93 | if (i == j): 94 | w_adj[i][j] = 1 95 | 96 | print("Total route number: ", n) 97 | print("Sparsity of adj: ", len(w_adj.nonzero()[0]) / (n * n)) 98 | print("The weighted matrix of temporal graph is generated!") 99 | dtw = w_adj 100 | return dtw 101 | 102 | def construct_adj_fusion(A, A_dtw, steps): 103 | ''' 104 | construct a bigger adjacency matrix using the given matrix 105 | 106 | Parameters 107 | ---------- 108 | A: np.ndarray, adjacency matrix, shape is (N, N) 109 | 110 | steps: how many times of the does the new adj mx bigger than A 111 | 112 | Returns 113 | ---------- 114 | new adjacency matrix: csr_matrix, shape is (N * steps, N * steps) 115 | 116 | ---------- 117 | This is 4N_1 mode: 118 | 119 | [T, 1, 1, T 120 | 1, S, 1, 1 121 | 1, 1, S, 1 122 | T, 1, 1, T] 123 | 124 | ''' 125 | 126 | N = len(A) 127 | adj = np.zeros([N * steps] * 2) # "steps" = 4 !!! 128 | 129 | for i in range(steps): 130 | if (i == 1) or (i == 2): 131 | adj[i * N: (i + 1) * N, i * N: (i + 1) * N] = A 132 | else: 133 | adj[i * N: (i + 1) * N, i * N: (i + 1) * N] = A_dtw 134 | #''' 135 | for i in range(N): 136 | for k in range(steps - 1): 137 | adj[k * N + i, (k + 1) * N + i] = 1 138 | adj[(k + 1) * N + i, k * N + i] = 1 139 | #''' 140 | adj[3 * N: 4 * N, 0: N] = A_dtw #adj[0 * N : 1 * N, 1 * N : 2 * N] 141 | adj[0 : N, 3 * N: 4 * N] = A_dtw #adj[0 * N : 1 * N, 1 * N : 2 * N] 142 | 143 | adj[2 * N: 3 * N, 0 : N] = adj[0 * N : 1 * N, 1 * N : 2 * N] 144 | adj[0 : N, 2 * N: 3 * N] = adj[0 * N : 1 * N, 1 * N : 2 * N] 145 | adj[1 * N: 2 * N, 3 * N: 4 * N] = adj[0 * N : 1 * N, 1 * N : 2 * N] 146 | adj[3 * N: 4 * N, 1 * N: 2 * N] = adj[0 * N : 1 * N, 1 * N : 2 * N] 147 | 148 | 149 | for i in range(len(adj)): 150 | adj[i, i] = 1 151 | 152 | return adj 153 | 154 | 155 | def parse_args(DATASET, args_base): 156 | # get configuration 157 | config_file = '../conf/STFGNN/{}.conf'.format(DATASET) 158 | config = configparser.ConfigParser() 159 | config.read(config_file) 160 | 161 | parser = argparse.ArgumentParser() 162 | # data 163 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 164 | parser.add_argument('--window', type=int, default=config['data']['window']) 165 | parser.add_argument('--horizon', type=int, default=config['data']['horizon']) 166 | parser.add_argument('--order', type=int, default=config['data']['order']) 167 | parser.add_argument('--lag', type=int, default=config['data']['lag']) 168 | parser.add_argument('--period', type=int, default=config['data']['period']) 169 | parser.add_argument('--sparsity', type=float, default=config['data']['sparsity']) 170 | # model 171 | parser.add_argument('--hidden_dims', type=list, default=config['model']['hidden_dims']) 172 | parser.add_argument('--first_layer_embedding_size', type=int, default=config['model']['first_layer_embedding_size']) 173 | parser.add_argument('--out_layer_dim', type=int, default=config['model']['out_layer_dim']) 174 | parser.add_argument('--output_dim', type=int, default=config['model']['output_dim']) 175 | parser.add_argument('--strides', type=int, default=config['model']['strides']) 176 | parser.add_argument('--temporal_emb', type=eval, default=config['model']['temporal_emb']) 177 | parser.add_argument('--spatial_emb', type=eval, default=config['model']['spatial_emb']) 178 | parser.add_argument('--use_mask', type=eval, default=config['model']['use_mask']) 179 | parser.add_argument('--activation', type=str, default=config['model']['activation']) 180 | parser.add_argument('--module_type', type=str, default=config['model']['module_type']) 181 | 182 | 183 | args, _ = parser.parse_known_args() 184 | args.filepath = '../data/' + DATASET +'/' 185 | args.filename = DATASET 186 | data = load_st_dataset(DATASET, args_base) 187 | 188 | filename = DATASET 189 | 190 | if not os.path.exists(f'../data/STFGNN/{filename}/{filename}_adj_mx.npy'): 191 | if DATASET == 'PEMS07M': 192 | A = weight_matrix(args.filepath + DATASET + '.csv') 193 | elif DATASET == 'NYC_BIKE': 194 | A = pd.read_csv(args.filepath + DATASET + ".csv", header=None).values.astype(np.float32) 195 | elif DATASET == 'chengdu_didi': 196 | A = np.load(args.filepath + 'matrix.npy').astype(np.float32) 197 | elif DATASET == 'CA_District5': 198 | A = np.load(args.filepath + '.npy').astype(np.float32) 199 | else: 200 | A, Distance = get_adjacency_matrix( 201 | distance_df_filename=args.filepath + DATASET + '.csv', 202 | num_of_vertices=args.num_nodes) 203 | dtw = construct_dtw(data, DATASET) 204 | adj_mx = construct_adj_fusion(A, dtw, args.strides) 205 | np.save(f'../data/STFGNN/{filename}/{filename}_adj_mx.npy', adj_mx) 206 | 207 | adj_STFGNN = np.load(f'../data/STFGNN/{filename}/{filename}_adj_mx.npy') 208 | args.adj = torch.Tensor(adj_STFGNN) 209 | if DATASET == 'PEMS07M': 210 | args.hidden_dims = [[64, 64, 64], [64, 64, 64], [64, 64, 64]] 211 | else: 212 | args.hidden_dims = [[64, 64, 64]] 213 | return args 214 | 215 | -------------------------------------------------------------------------------- /model/STGCN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import configparser 4 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 5 | import torch 6 | import pandas as pd 7 | 8 | def scaled_laplacian(W): 9 | ''' 10 | Normalized graph Laplacian function. 11 | :param W: np.ndarray, [n_route, n_route], weighted adjacency matrix of G. 12 | :return: np.matrix, [n_route, n_route]. 13 | ''' 14 | # d -> diagonal degree matrix 15 | n, d = np.shape(W)[0], np.sum(W, axis=1) 16 | # L -> graph Laplacian 17 | L = -W 18 | L[np.diag_indices_from(L)] = d 19 | # 返回索引以访问n维数组的主对角线。 20 | for i in range(n): 21 | for j in range(n): 22 | if (d[i] > 0) and (d[j] > 0): 23 | L[i, j] = L[i, j] / np.sqrt(d[i] * d[j]) 24 | # lambda_max \approx 2.0, the largest eigenvalues of L.L的最大特征值 25 | # lambda_max = eigs(L, k=1, which='LR')[0][0].real 26 | lambda_max = np.linalg.eigvals(L).max().real 27 | return np.mat(2 * L / lambda_max - np.identity(n)) 28 | 29 | 30 | def cheb_poly_approx(L, Ks, n): 31 | ''' 32 | Chebyshev polynomials approximation function. 33 | :param L: np.matrix, [n_route, n_route], graph Laplacian. 34 | :param Ks: int, kernel size of spatial convolution. 35 | :param n: int, number of routes / size of graph. 36 | :return: np.ndarray, [n_route, Ks*n_route]. 37 | ''' 38 | L0, L1 = np.mat(np.identity(n)), np.mat(np.copy(L)) 39 | 40 | if Ks > 1: 41 | L_list = [np.copy(L0), np.copy(L1)] 42 | for i in range(Ks - 2): 43 | Ln = np.mat(2 * L * L1 - L0) 44 | L_list.append(np.copy(Ln)) 45 | L0, L1 = np.matrix(np.copy(L1)), np.matrix(np.copy(Ln)) 46 | # L_lsit [Ks, n*n], Lk [n, Ks*n] 47 | return np.stack(L_list, axis=0) 48 | elif Ks == 1: 49 | return np.asarray(L0) 50 | else: 51 | raise ValueError(f'ERROR: the size of spatial kernel must be greater than 1, but received "{Ks}".') 52 | 53 | def parse_args(DATASET, args_base): 54 | # get configuration 55 | config_file = '../conf/STGCN/{}.conf'.format(DATASET) 56 | config = configparser.ConfigParser() 57 | config.read(config_file) 58 | parser = argparse.ArgumentParser(prefix_chars='--', description='predictor_based_arguments') 59 | 60 | blocks1_str = config.get('model', 'blocks1') 61 | blocks1 = eval(blocks1_str) 62 | # data 63 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 64 | parser.add_argument('--input_window', type=int, default=config['data']['input_window']) 65 | parser.add_argument('--output_window', type=int, default=config['data']['output_window']) 66 | # model 67 | parser.add_argument('--Ks', type=int, default=config['model']['Ks']) 68 | parser.add_argument('--Kt', type=int, default=config['model']['Kt']) 69 | parser.add_argument('--blocks1', type=list, default=blocks1) 70 | parser.add_argument('--drop_prob', type=int, default=config['model']['drop_prob']) 71 | parser.add_argument('--outputl_ks', type=int, default=config['model']['outputl_ks']) 72 | args, _ = parser.parse_known_args() 73 | 74 | G_dict = {} 75 | for data_graph in args_base.dataset_graph: 76 | L = scaled_laplacian(args_base.A_dict_np[data_graph]) 77 | Lk = cheb_poly_approx(L, 3, args_base.A_dict_np[data_graph].shape[0]) 78 | G_dict[data_graph] = torch.FloatTensor(Lk) 79 | 80 | args.G_dict = G_dict 81 | return args -------------------------------------------------------------------------------- /model/STGCN/stgcn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from logging import getLogger 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import torch.nn.functional as F 8 | 9 | 10 | class Align(nn.Module): 11 | def __init__(self, c_in, c_out): 12 | super(Align, self).__init__() 13 | self.c_in = c_in 14 | self.c_out = c_out 15 | if c_in > c_out: 16 | self.conv1x1 = nn.Conv2d(c_in, c_out, 1) # filter=(1,1) 17 | 18 | def forward(self, x): # x: (batch_size, feature_dim(c_in), input_length, num_nodes) 19 | if self.c_in > self.c_out: 20 | return self.conv1x1(x) 21 | if self.c_in < self.c_out: 22 | return F.pad(x, [0, 0, 0, 0, 0, self.c_out - self.c_in, 0, 0]) 23 | return x # return: (batch_size, c_out, input_length-1+1, num_nodes-1+1) 24 | 25 | class TemporalConvLayer(nn.Module): 26 | def __init__(self, kt, c_in, c_out, act="relu"): 27 | super(TemporalConvLayer, self).__init__() 28 | self.kt = kt 29 | self.act = act 30 | self.c_out = c_out 31 | self.align = Align(c_in, c_out) 32 | if self.act == "GLU": 33 | self.conv = nn.Conv2d(c_in, c_out * 2, (kt, 1), 1, padding=[int((kt-1)/2), 0]) 34 | else: 35 | self.conv = nn.Conv2d(c_in, c_out, (kt, 1), 1, padding=[int((kt-1)/2), 0]) 36 | 37 | def forward(self, x): 38 | """ 39 | 40 | :param x: (batch_size, feature_dim(c_in), input_length, num_nodes) 41 | :return: (batch_size, c_out, input_length-kt+1, num_nodes) 42 | """ 43 | # x_in = self.align(x)[:, :, self.kt - 1:, :] # (batch_size, c_out, input_length-kt+1, num_nodes) 44 | x_in = self.align(x)[:, :, :, :] # (batch_size, c_out, input_length-kt+1, num_nodes) 45 | if self.act == "GLU": 46 | # x: (batch_size, c_in, input_length, num_nodes) 47 | x_conv = self.conv(x) 48 | # x_conv: (batch_size, c_out * 2, input_length-kt+1, num_nodes) [P Q] 49 | return (x_conv[:, :self.c_out, :, :] + x_in) * torch.sigmoid(x_conv[:, self.c_out:, :, :]) 50 | # return P * sigmoid(Q) shape: (batch_size, c_out, input_length-kt+1, num_nodes) 51 | if self.act == "sigmoid": 52 | return torch.sigmoid(self.conv(x) + x_in) # residual connection 53 | return torch.relu(self.conv(x) + x_in) # residual connection 54 | 55 | class SpatioConvLayer(nn.Module): 56 | def __init__(self, ks, c_in, c_out, lk, device): 57 | super(SpatioConvLayer, self).__init__() 58 | self.Lk = lk 59 | self.device = device 60 | self.theta = nn.Parameter(torch.FloatTensor(c_in, c_out, ks).to(device)) # kernel: C_in*C_out*ks 61 | self.b = nn.Parameter(torch.FloatTensor(1, c_out, 1, 1).to(device)) 62 | self.align = Align(c_in, c_out) 63 | self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | init.kaiming_uniform_(self.theta, a=math.sqrt(5)) 67 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.theta) 68 | bound = 1 / math.sqrt(fan_in) 69 | init.uniform_(self.b, -bound, bound) 70 | 71 | def forward(self, x, select_dataset): 72 | # Lk: (Ks, num_nodes, num_nodes) 73 | # x: (batch_size, c_in, input_length, num_nodes) 74 | # x_c: (batch_size, c_in, input_length, Ks, num_nodes) 75 | # theta: (c_in, c_out, Ks) 76 | # x_gc: (batch_size, c_out, input_length, num_nodes) 77 | # print(select_dataset, self.Lk[select_dataset]) 78 | x_c = torch.einsum("knm,bitm->bitkn", self.Lk[select_dataset].to(self.device), x) # delete num_nodes(n) 79 | x_gc = torch.einsum("iok,bitkn->botn", self.theta, x_c) + self.b # delete Ks(k) c_in(i) 80 | x_in = self.align(x) # (batch_size, c_out, input_length, num_nodes) 81 | return torch.relu(x_gc + x_in) # residual connection 82 | 83 | 84 | class STConvBlock(nn.Module): 85 | def __init__(self, ks, kt, n, c, p, lk, dataset_use, dataset_test, device, mode): 86 | super(STConvBlock, self).__init__() 87 | self.mode = mode 88 | self.tconv1 = TemporalConvLayer(kt, c[0], c[1], "GLU") 89 | self.sconv = SpatioConvLayer(ks, c[1], c[1], lk, device) 90 | self.tconv2 = TemporalConvLayer(kt, c[1], c[2]) 91 | self.dataset2index = {} 92 | if self.mode == 'pretrain': 93 | self.ln_pretrain = nn.ModuleList() 94 | for i, data_graph in enumerate(dataset_use): 95 | self.dataset2index[data_graph] = i 96 | n_dataset = lk[data_graph].shape[1] 97 | self.ln_pretrain.append(nn.LayerNorm([n_dataset, c[2]])) 98 | else: 99 | self.ln_eval = nn.ModuleList() 100 | for i, data_graph in enumerate(dataset_test): 101 | self.dataset2index[data_graph] = i 102 | n_dataset = lk[data_graph].shape[1] 103 | self.ln_eval.append(nn.LayerNorm([n_dataset, c[2]])) 104 | self.dropout = nn.Dropout(p) 105 | 106 | def forward(self, x, select_dataset): # x: (batch_size, feature_dim/c[0], input_length, num_nodes) 107 | x_t1 = self.tconv1(x) # (batch_size, c[1], input_length-kt+1, num_nodes) 108 | x_s = self.sconv(x_t1, select_dataset) # (batch_size, c[1], input_length-kt+1, num_nodes) 109 | x_t2 = self.tconv2(x_s) # (batch_size, c[2], input_length-kt+1-kt+1, num_nodes) 110 | x_t2 = x_t2.permute(0, 2, 3, 1).permute(0, 3, 1, 2) 111 | if self.mode == 'pretrain': 112 | x_ln = self.ln_pretrain[self.dataset2index[select_dataset]](x_t2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 113 | else: 114 | x_ln = self.ln_eval[self.dataset2index[select_dataset]](x_t2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 115 | return self.dropout(x_ln) 116 | 117 | class FullyConvLayer(nn.Module): 118 | def __init__(self, c, out_dim): 119 | super(FullyConvLayer, self).__init__() 120 | self.conv = nn.Conv2d(c, out_dim, 1) # c,self.output_dim,1 121 | 122 | def forward(self, x): 123 | return self.conv(x) 124 | 125 | 126 | class OutputLayer(nn.Module): 127 | def __init__(self, c, t, n, out_dim, lk, dataset_use, dataset_test, mode): 128 | super(OutputLayer, self).__init__() 129 | self.tconv1 = TemporalConvLayer(t, c, c, "GLU") 130 | self.dataset2index = {} 131 | self.mode = mode 132 | self.fc_pretrain = FullyConvLayer(c, out_dim) 133 | if self.mode == 'pretrain': 134 | self.ln_pretrain = nn.ModuleList() 135 | for i, data_graph in enumerate(dataset_use): 136 | self.dataset2index[data_graph] = i 137 | n_dataset = lk[data_graph].shape[1] 138 | self.ln_pretrain.append(nn.LayerNorm([n_dataset, c])) 139 | else: 140 | self.ln_eval = nn.ModuleList() 141 | for i, data_graph in enumerate(dataset_test): 142 | self.dataset2index[data_graph] = i 143 | n_dataset = lk[data_graph].shape[1] 144 | self.ln_eval.append(nn.LayerNorm([n_dataset, c])) 145 | # self.fc_eval = FullyConvLayer(c, out_dim) 146 | self.tconv2 = TemporalConvLayer(1, c, c, "sigmoid") # kernel=1*1 147 | 148 | def forward(self, x, select_dataset): 149 | # (batch_size, input_dim(c), T, num_nodes) 150 | x_t1 = self.tconv1(x) 151 | # (batch_size, input_dim(c), 1, num_nodes) 152 | if self.mode == 'pretrain': 153 | x_t1 = self.ln_pretrain[self.dataset2index[select_dataset]](x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 154 | else: 155 | x_t1 = self.ln_eval[self.dataset2index[select_dataset]](x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 156 | # x_t1 = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 157 | x_ln = x_t1.permute(0, 2, 3, 1).permute(0, 3, 1, 2) 158 | # (batch_size, input_dim(c), 1, num_nodes) 159 | x_t2 = self.tconv2(x_ln) 160 | return self.fc_pretrain(x_t2) 161 | 162 | 163 | class STGCN(nn.Module): 164 | def __init__(self, args, G_dict, dataset_use, dataset_test, num_nodes, dim_in, dim_out, device, mode): 165 | super(STGCN, self).__init__() 166 | self.Ks = args.Ks 167 | self.Kt = args.Kt 168 | self.num_nodes = num_nodes 169 | self.G_dict = G_dict 170 | self.blocks0 = [dim_in, args.blocks1[1], args.blocks1[0]] 171 | self.blocks1 = args.blocks1 172 | self.drop_prob = args.drop_prob 173 | self.device = device 174 | self.st_conv1 = STConvBlock(self.Ks, self.Kt, self.num_nodes, 175 | self.blocks0, self.drop_prob, self.G_dict, dataset_use, dataset_test, self.device, mode) 176 | self.st_conv2 = STConvBlock(self.Ks, self.Kt, self.num_nodes, 177 | self.blocks1, self.drop_prob, self.G_dict, dataset_use, dataset_test, self.device, mode) 178 | self.output = OutputLayer(args.blocks1[2], args.outputl_ks, self.num_nodes, dim_out, self.G_dict, dataset_use, dataset_test, mode) 179 | 180 | def forward(self, x, select_dataset): 181 | # print(x.shape) 182 | x = x.permute(0, 3, 1, 2) # (batch_size, feature_dim, input_length, num_nodes) 183 | # print(x.shape) 184 | x_st1 = self.st_conv1(x, select_dataset) # (batch_size, c[2](64), input_length-kt+1-kt+1, num_nodes) 185 | # print(x_st1.shape) 186 | x_st2 = self.st_conv2(x_st1, select_dataset) # (batch_size, c[2](128), input_length-kt+1-kt+1-kt+1-kt+1, num_nodes) 187 | # print(x_st2.shape) 188 | outputs1 = self.output(x_st2, select_dataset) # (batch_size, output_dim(1), output_length(1), num_nodes) 189 | # print(outputs.shape) 190 | outputs2 = outputs1.permute(0, 2, 3, 1) # (batch_size, output_length(1), num_nodes, output_dim) 191 | # print(outputs2.shape) 192 | return outputs2 -------------------------------------------------------------------------------- /model/STGODE/STGODE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .odegcn import ODEG 7 | 8 | 9 | class Chomp1d(nn.Module): 10 | """ 11 | extra dimension will be added by padding, remove it 12 | """ 13 | 14 | def __init__(self, chomp_size): 15 | super(Chomp1d, self).__init__() 16 | self.chomp_size = chomp_size 17 | 18 | def forward(self, x): 19 | return x[:, :, :, :-self.chomp_size].contiguous() 20 | 21 | 22 | class TemporalConvNet(nn.Module): 23 | """ 24 | time dilation convolution 25 | """ 26 | 27 | def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): 28 | """ 29 | Args: 30 | num_inputs : channel's number of input data's feature 31 | num_channels : numbers of data feature tranform channels, the last is the output channel 32 | kernel_size : using 1d convolution, so the real kernel is (1, kernel_size) 33 | """ 34 | super(TemporalConvNet, self).__init__() 35 | layers = [] 36 | num_levels = len(num_channels) 37 | for i in range(num_levels): 38 | dilation_size = 2 ** i 39 | in_channels = num_inputs if i == 0 else num_channels[i - 1] 40 | out_channels = num_channels[i] 41 | padding = (kernel_size - 1) * dilation_size 42 | self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size), 43 | padding=(0, padding)) 44 | self.conv.weight.data.normal_(0, 0.01) 45 | self.chomp = Chomp1d(padding) 46 | self.relu = nn.ReLU() 47 | self.dropout = nn.Dropout(dropout) 48 | 49 | layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)] 50 | 51 | self.network = nn.Sequential(*layers) 52 | self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None 53 | if self.downsample: 54 | self.downsample.weight.data.normal_(0, 0.01) 55 | 56 | def forward(self, x, select_dataset=None): 57 | """ 58 | like ResNet 59 | Args: 60 | X : input data of shape (B, N, T, F) 61 | """ 62 | # permute shape to (B, F, N, T) 63 | y = x.permute(0, 3, 1, 2) 64 | y = F.relu(self.network(y) + self.downsample(y) if self.downsample else y) 65 | y = y.permute(0, 2, 3, 1) 66 | return y 67 | 68 | 69 | class GCN(nn.Module): 70 | def __init__(self, A_hat, in_channels, out_channels, ): 71 | super(GCN, self).__init__() 72 | self.A_hat = A_hat 73 | self.theta = nn.Parameter(torch.FloatTensor(in_channels, out_channels)) 74 | self.reset() 75 | 76 | def reset(self): 77 | stdv = 1. / math.sqrt(self.theta.shape[1]) 78 | self.theta.data.uniform_(-stdv, stdv) 79 | 80 | def forward(self, X): 81 | y = torch.einsum('ij, kjlm-> kilm', self.A_hat, X) 82 | return F.relu(torch.einsum('kjlm, mn->kjln', y, self.theta)) 83 | 84 | 85 | class STGCNBlock(nn.Module): 86 | def __init__(self, time_steps, in_channels, out_channels, num_nodes, A_hat, A_dict, dataset_use, dataset_test, mode, device): 87 | """ 88 | Args: 89 | in_channels: Number of input features at each node in each time step. 90 | out_channels: a list of feature channels in timeblock, the last is output feature channel 91 | num_nodes: Number of nodes in the graph 92 | A_hat: the normalized adjacency matrix 93 | """ 94 | super(STGCNBlock, self).__init__() 95 | self.A_hat = A_hat 96 | self.temporal1 = TemporalConvNet(num_inputs=in_channels, 97 | num_channels=out_channels) 98 | self.odeg = ODEG(out_channels[-1], time_steps, A_hat, time=6, dataset_use=dataset_use, dataset_test=dataset_test, mode=mode, device=device) 99 | self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], 100 | num_channels=out_channels) 101 | 102 | self.dataset2index = {} 103 | self.mode = mode 104 | if mode == 'pretrain': 105 | self.bn_pretrain = nn.ModuleList() 106 | for i, data_graph in enumerate(dataset_use): 107 | self.dataset2index[data_graph] = i 108 | n_dataset = A_dict[data_graph].shape[0] 109 | self.bn_pretrain.append(nn.BatchNorm2d(n_dataset)) 110 | else: 111 | self.bn_eval = nn.ModuleList() 112 | for i, data_graph in enumerate([dataset_test]): 113 | self.dataset2index[data_graph] = i 114 | n_dataset = A_dict[data_graph].shape[0] 115 | self.bn_eval.append(nn.BatchNorm2d(n_dataset)) 116 | 117 | # self.batch_norm = nn.BatchNorm2d(num_nodes) 118 | 119 | def forward(self, X, select_dataset): 120 | # """ 121 | # Args: 122 | # X: Input data of shape (batch_size, num_nodes, num_timesteps, num_features) 123 | # Return: 124 | # Output data of shape(batch_size, num_nodes, num_timesteps, out_channels[-1]) 125 | # """ 126 | t = self.temporal1(X) 127 | t = self.odeg(t, select_dataset) 128 | t = self.temporal2(F.relu(t)) 129 | 130 | if self.mode == 'pretrain': 131 | return self.bn_pretrain[self.dataset2index[select_dataset]](t) 132 | else: 133 | return self.bn_eval[self.dataset2index[select_dataset]](t) 134 | # return self.batch_norm(t) 135 | 136 | 137 | class ODEGCN(nn.Module): 138 | """ the overall network framework """ 139 | 140 | def __init__(self, num_nodes, num_features, num_timesteps_input, 141 | num_timesteps_output, A_sp_hat, A_se_hat, dim_out, A_dict, dataset_use, dataset_test, mode, device): 142 | """ 143 | Args: 144 | num_nodes : number of nodes in the graph 145 | num_features : number of features at each node in each time step 146 | num_timesteps_input : number of past time steps fed into the network 147 | num_timesteps_output : desired number of future time steps output by the network 148 | A_sp_hat : nomarlized adjacency spatial matrix 149 | A_se_hat : nomarlized adjacency semantic matrix 150 | """ 151 | 152 | super(ODEGCN, self).__init__() 153 | 154 | # spatial graph 155 | self.sp_blocks1 = nn.ModuleList( 156 | [ 157 | STGCNBlock(num_timesteps_input, in_channels=num_features, out_channels=[64, 32, 64], 158 | num_nodes=num_nodes, A_hat=A_sp_hat, A_dict=A_dict, dataset_use=dataset_use, 159 | dataset_test=dataset_test, mode=mode, device=device) for _ in range(3) 160 | ]) 161 | self.sp_blocks2 = nn.ModuleList( 162 | [ 163 | STGCNBlock(num_timesteps_input, in_channels=64, out_channels=[64, 32, 64], 164 | num_nodes=num_nodes, A_hat=A_sp_hat, A_dict=A_dict, dataset_use=dataset_use, 165 | dataset_test=dataset_test, mode=mode, device=device) for _ in range(3) 166 | ]) 167 | 168 | 169 | # self.sp_blocks = nn.ModuleList( 170 | # [nn.Sequential( 171 | # STGCNBlock(num_timesteps_input, in_channels=num_features, out_channels=[64, 32, 64], 172 | # num_nodes=num_nodes, A_hat=A_sp_hat, A_dict=A_dict, dataset_use=dataset_use, 173 | # dataset_test=dataset_test, mode=mode), 174 | # STGCNBlock(num_timesteps_input, in_channels=64, out_channels=[64, 32, 64], 175 | # num_nodes=num_nodes, A_hat=A_sp_hat, A_dict=A_dict, dataset_use=dataset_use, 176 | # dataset_test=dataset_test, mode=mode)) for _ in range(3) 177 | # ]) 178 | 179 | # semantic graph 180 | self.se_blocks1 = nn.ModuleList([ 181 | STGCNBlock(num_timesteps_input, in_channels=num_features, out_channels=[64, 32, 64], 182 | num_nodes=num_nodes, A_hat=A_se_hat, A_dict=A_dict, dataset_use=dataset_use, 183 | dataset_test=dataset_test, mode=mode, device=device) for _ in range(3) 184 | ]) 185 | 186 | self.se_blocks2 = nn.ModuleList([ 187 | STGCNBlock(num_timesteps_input, in_channels=64, out_channels=[64, 32, 64], 188 | num_nodes=num_nodes, A_hat=A_se_hat, A_dict=A_dict, dataset_use=dataset_use, 189 | dataset_test=dataset_test, mode=mode, device=device) for _ in range(3) 190 | ]) 191 | 192 | # self.se_blocks = nn.ModuleList([nn.Sequential( 193 | # STGCNBlock(num_timesteps_input, in_channels=num_features, out_channels=[64, 32, 64], 194 | # num_nodes=num_nodes, A_hat=A_se_hat, A_dict=A_dict, dataset_use=dataset_use, 195 | # dataset_test=dataset_test, mode=mode), 196 | # STGCNBlock(num_timesteps_input, in_channels=64, out_channels=[64, 32, 64], 197 | # num_nodes=num_nodes, A_hat=A_se_hat, A_dict=A_dict, dataset_use=dataset_use, 198 | # dataset_test=dataset_test, mode=mode)) for _ in range(3) 199 | # ]) 200 | 201 | self.pred = nn.Sequential( 202 | nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32), 203 | nn.ReLU(), 204 | nn.Linear(num_timesteps_output * 32, num_timesteps_output * dim_out) 205 | ) 206 | self.dim_out = dim_out 207 | 208 | def forward(self, x, select_dataset): 209 | # """ 210 | # Args: 211 | # x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F) 212 | # Returns: 213 | # prediction for future of shape (batch_size, num_nodes, num_timesteps_output) 214 | # """ 215 | x = x.transpose(1, 2) 216 | outs = [] 217 | 218 | # spatial graph 219 | for blk1, blk2 in zip(self.sp_blocks1, self.sp_blocks2): 220 | x1 = blk1(x, select_dataset) 221 | outs.append(blk2(x1, select_dataset)) 222 | # semantic graph 223 | for blk1, blk2 in zip(self.se_blocks1, self.se_blocks2): 224 | x1 = blk1(x, select_dataset) 225 | outs.append(blk2(x1, select_dataset)) 226 | outs = torch.stack(outs) 227 | x = torch.max(outs, dim=0)[0] 228 | x = x.reshape((x.shape[0], x.shape[1], -1)) 229 | if self.dim_out != 1: 230 | batch, node_num, time_step = x.shape[0], x.shape[1], x.shape[2] 231 | out_pred = self.pred(x).unsqueeze(-1).reshape(batch, node_num, -1, self.dim_out).transpose(1, 2) 232 | else: 233 | out_pred = self.pred(x).unsqueeze(-1).transpose(1, 2) 234 | # print(out_pred.shape) 235 | 236 | return out_pred 237 | -------------------------------------------------------------------------------- /model/STGODE/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | from fastdtw import fastdtw 5 | from tqdm import tqdm 6 | import csv 7 | import torch 8 | import pandas as pd 9 | 10 | def read_data(args): 11 | """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw 12 | 13 | Args: 14 | sigma1: float, default=0.1, sigma for the semantic matrix 15 | sigma2: float, default=10, sigma for the spatial matrix 16 | thres1: float, default=0.6, the threshold for the semantic matrix 17 | thres2: float, default=0.5, the threshold for the spatial matrix 18 | 19 | Returns: 20 | data: tensor, T * N * 1 21 | dtw_matrix: array, semantic adjacency matrix 22 | sp_matrix: array, spatial adjacency matrix 23 | """ 24 | filename = args.filename 25 | # filepath = "./data/" 26 | filepath = args.filepath 27 | # if args.remote: 28 | # filepath = '/home/lantu.lqq/ftemp/data/' 29 | data = np.load(filepath + filename + '.npz')['data'] 30 | if data.ndim != 3: 31 | data = np.expand_dims(data, axis=-1) 32 | # print(data.shape) 33 | # PEMS04 == shape: (16992, 307, 3) feature: flow,occupy,speed 34 | # PEMSD7M == shape: (12672, 228, 1) 35 | # PEMSD7L == shape: (12672, 1026, 1) 36 | num_node = data.shape[1] 37 | mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1) 38 | std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1) 39 | data = (data - mean_value) / std_value 40 | mean_value = mean_value.reshape(-1)[0] 41 | std_value = std_value.reshape(-1)[0] 42 | # print(sss) 43 | 44 | if not os.path.exists(f'../data/STGODE/{filename}/{filename}_dtw_distance.npy'): 45 | data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))], 46 | axis=0) 47 | data_mean = data_mean.squeeze().T 48 | dtw_distance = np.zeros((num_node, num_node)) 49 | for i in tqdm(range(num_node)): 50 | for j in range(i, num_node): 51 | dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0] 52 | for i in range(num_node): 53 | for j in range(i): 54 | dtw_distance[i][j] = dtw_distance[j][i] 55 | np.save(f'../data/STGODE/{filename}/{filename}_dtw_distance.npy', dtw_distance) 56 | 57 | dist_matrix = np.load(f'../data/STGODE/{filename}/{filename}_dtw_distance.npy') 58 | 59 | mean = np.mean(dist_matrix) 60 | std = np.std(dist_matrix) 61 | dist_matrix = (dist_matrix - mean) / std 62 | sigma = args.sigma1 63 | dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2) 64 | dtw_matrix = np.zeros_like(dist_matrix) 65 | dtw_matrix[dist_matrix > args.thres1] = 1 66 | 67 | # # use continuous semantic matrix 68 | # if not os.path.exists(f'data/{filename}_dtw_c_matrix.npy'): 69 | # dist_matrix = np.load(f'data/{filename}_dtw_distance.npy') 70 | # # normalization 71 | # std = np.std(dist_matrix[dist_matrix != np.float('inf')]) 72 | # mean = np.mean(dist_matrix[dist_matrix != np.float('inf')]) 73 | # dist_matrix = (dist_matrix - mean) / std 74 | # sigma = 0.1 75 | # dtw_matrix = np.exp(- dist_matrix**2 / sigma**2) 76 | # dtw_matrix[dtw_matrix < 0.5] = 0 77 | # np.save(f'data/{filename}_dtw_c_matrix.npy', dtw_matrix) 78 | # dtw_matrix = np.load(f'data/{filename}_dtw_c_matrix.npy') 79 | 80 | # use continuous spatial matrix 81 | if not os.path.exists(f'../data/STGODE/{filename}/{filename}_spatial_distance.npy'): 82 | if filename == 'PEMS07M': 83 | dist_matrix = pd.read_csv(filepath + filename + '.csv', header=None).values 84 | elif filename == 'NYC_BIKE': 85 | dist_matrix = pd.read_csv(filepath + filename + '.csv', header=None).values.astype(np.float32) 86 | dist_matrix = (1 - dist_matrix) * 1000 87 | elif filename == 'chengdu_didi': 88 | dist_matrix = np.load(filepath + 'matrix.npy').astype(np.float32) 89 | dist_matrix = (1 - dist_matrix) * 1000 90 | elif filename == 'CA_District5': 91 | dist_matrix = np.load(filepath + filename + '.npy').astype(np.float32) 92 | dist_matrix = (1 - dist_matrix) * 1000 93 | elif filename == 'PEMS03': 94 | with open(filepath + filename + '.txt', 'r') as f: 95 | id_dict = {int(i): idx for idx, i in 96 | enumerate(f.read().strip().split('\n'))} # 把节点id(idx)映射成从0开始的索引 97 | f.readline() 98 | reader = csv.reader(f) 99 | for row in reader: 100 | if len(row) != 3: 101 | continue 102 | i, j, distance = int(row[0]), int(row[1]), float(row[2]) 103 | dist_matrix[id_dict[i], id_dict[j]] = distance 104 | else: 105 | dist_matrix = np.zeros((num_node, num_node)) + np.float('inf') 106 | with open(filepath + filename + '.csv', 'r') as fp: 107 | file = csv.reader(fp) 108 | for line in file: 109 | break 110 | for line in file: 111 | print(line) 112 | print(line[0]) 113 | start = int(line[0]) 114 | end = int(line[1]) 115 | dist_matrix[start][end] = float(line[2]) 116 | dist_matrix[end][start] = float(line[2]) 117 | np.save(f'../data/STGODE/{filename}/{filename}_spatial_distance.npy', dist_matrix) 118 | 119 | # use 0/1 spatial matrix 120 | # if not os.path.exists(f'data/{filename}_sp_matrix.npy'): 121 | # dist_matrix = np.load(f'data/{filename}_spatial_distance.npy') 122 | # sp_matrix = np.zeros((num_node, num_node)) 123 | # sp_matrix[dist_matrix != np.float('inf')] = 1 124 | # np.save(f'data/{filename}_sp_matrix.npy', sp_matrix) 125 | # sp_matrix = np.load(f'data/{filename}_sp_matrix.npy') 126 | 127 | dist_matrix = np.load(f'../data/STGODE/{filename}/{filename}_spatial_distance.npy') 128 | # normalization 129 | std = np.std(dist_matrix[dist_matrix != np.float('inf')]) 130 | mean = np.mean(dist_matrix[dist_matrix != np.float('inf')]) 131 | dist_matrix = (dist_matrix - mean) / std 132 | sigma = args.sigma2 133 | sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2) 134 | sp_matrix[sp_matrix < args.thres2] = 0 135 | # np.save(f'data/{filename}_sp_c_matrix.npy', sp_matrix) 136 | # sp_matrix = np.load(f'data/{filename}_sp_c_matrix.npy') 137 | 138 | print(f'average degree of spatial graph is {np.sum(sp_matrix > 0) / 2 / num_node}') 139 | print(f'average degree of semantic graph is {np.sum(dtw_matrix > 0) / 2 / num_node}') 140 | return torch.from_numpy(data.astype(np.float32)), mean_value, std_value, dtw_matrix, sp_matrix 141 | 142 | def get_normalized_adj(A): 143 | """ 144 | Returns a tensor, the degree normalized adjacency matrix. 145 | """ 146 | alpha = 0.8 147 | D = np.array(np.sum(A, axis=1)).reshape((-1,)) 148 | D[D <= 10e-5] = 10e-5 # Prevent infs 149 | diag = np.reciprocal(np.sqrt(D)) 150 | A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A), 151 | diag.reshape((1, -1))) 152 | A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave) 153 | return torch.from_numpy(A_reg.astype(np.float32)) 154 | 155 | 156 | def parse_args(DATASET, args_base): 157 | parser = argparse.ArgumentParser() 158 | # parser.add_argument('--num-gpu', type=int, default=0, help='the number of the gpu to use') 159 | # parser.add_argument('--batch-size', type=int, default=16, help='batch size') 160 | 161 | # parser.add_argument('--filename', type=str, default='pems08') 162 | # parser.add_argument('--filepath', type=str, default='F:/data/traffic_data/PEMS_data/') 163 | # parser.add_argument('--train-ratio', type=float, default=0.6, help='the ratio of training dataset') 164 | # parser.add_argument('--valid-ratio', type=float, default=0.2, help='the ratio of validating dataset') 165 | # parser.add_argument('--his-length', type=int, default=12, help='the length of history time series of input') 166 | # parser.add_argument('--pred-length', type=int, default=12, help='the length of target time series for prediction') 167 | 168 | parser.add_argument('--sigma1', type=float, default=0.1, help='sigma for the semantic matrix') 169 | parser.add_argument('--sigma2', type=float, default=10, help='sigma for the spatial matrix') 170 | parser.add_argument('--thres1', type=float, default=0.6, help='the threshold for the semantic matrix') 171 | parser.add_argument('--thres2', type=float, default=0.5, help='the threshold for the spatial matrix') 172 | # parser.add_argument('--lr', type=float, default=2e-3, help='learning rate') 173 | 174 | args, _ = parser.parse_known_args() 175 | 176 | A_sp_wave_dict = {} 177 | A_se_wave_dict = {} 178 | 179 | for data_graph in args_base.dataset_graph: 180 | args.filepath = '../data/' + data_graph +'/' 181 | args.filename = data_graph 182 | _, _, _, dtw_matrix, sp_matrix = read_data(args) 183 | A_sp_wave_dict[data_graph], A_se_wave_dict[data_graph] = get_normalized_adj(sp_matrix).to(args_base.device), get_normalized_adj(dtw_matrix).to(args_base.device) 184 | # args.A_sp_wave, args.A_se_wave = get_normalized_adj(sp_matrix), get_normalized_adj(dtw_matrix) 185 | args.A_sp_wave_dict, args.A_se_wave_dict = A_sp_wave_dict, A_se_wave_dict 186 | return args -------------------------------------------------------------------------------- /model/STGODE/odegcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | # Whether use adjoint method or not. 7 | adjoint = False 8 | if adjoint: 9 | from torchdiffeq import odeint_adjoint as odeint 10 | else: 11 | from torchdiffeq import odeint 12 | 13 | 14 | # Define the ODE function. 15 | # Input: 16 | # --- t: A tensor with shape [], meaning the current time. 17 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t. 18 | # Output: 19 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t. 20 | class ODEFunc(nn.Module): 21 | 22 | def __init__(self, feature_dim, temporal_dim, adj, dataset_use, dataset_test, mode, device): 23 | super(ODEFunc, self).__init__() 24 | self.adj = adj 25 | self.x0 = None 26 | self.mode = mode 27 | self.dataset2index = {} 28 | if mode == 'pretrain': 29 | self.alpha_pretrain = [] 30 | for i, data_graph in enumerate(dataset_use): 31 | self.dataset2index[data_graph] = i 32 | n_dataset = adj[data_graph].shape[1] 33 | self.alpha_pretrain.append(nn.Parameter(0.8 * torch.ones(n_dataset)).to(device)) 34 | else: 35 | self.alpha_eval = [] 36 | for i, data_graph in enumerate([dataset_test]): 37 | self.dataset2index[data_graph] = i 38 | n_dataset = adj[data_graph].shape[1] 39 | self.alpha_eval.append(nn.Parameter(0.8 * torch.ones(n_dataset)).to(device)) 40 | 41 | # self.alpha = nn.Parameter(0.8 * torch.ones(adj.shape[1])) 42 | self.beta = 0.6 43 | self.w = nn.Parameter(torch.eye(feature_dim)) 44 | self.d = nn.Parameter(torch.zeros(feature_dim) + 1) 45 | self.w2 = nn.Parameter(torch.eye(temporal_dim)) 46 | self.d2 = nn.Parameter(torch.zeros(temporal_dim) + 1) 47 | 48 | def forward(self, t, x, select_dataset): 49 | if self.mode == 'pretrain': 50 | alpha = torch.sigmoid(self.alpha_pretrain[self.dataset2index[select_dataset]]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0) 51 | else: 52 | alpha = torch.sigmoid(self.alpha_eval[self.dataset2index[select_dataset]]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0) 53 | 54 | # alpha = torch.sigmoid(self.alpha).unsqueeze(-1).unsqueeze(-1).unsqueeze(0) 55 | # print(self.adj.shape, x.shape) 56 | xa = torch.einsum('ij, kjlm->kilm', self.adj[select_dataset], x) 57 | 58 | # ensure the eigenvalues to be less than 1 59 | d = torch.clamp(self.d, min=0, max=1) 60 | w = torch.mm(self.w * d, torch.t(self.w)) 61 | xw = torch.einsum('ijkl, lm->ijkm', x, w) 62 | 63 | d2 = torch.clamp(self.d2, min=0, max=1) 64 | w2 = torch.mm(self.w2 * d2, torch.t(self.w2)) 65 | xw2 = torch.einsum('ijkl, km->ijml', x, w2) 66 | 67 | f = alpha / 2 * xa - x + xw - x + xw2 - x + self.x0 68 | return f 69 | 70 | 71 | class ODEblock(nn.Module): 72 | def __init__(self, odefunc, t=torch.tensor([0,1])): 73 | super(ODEblock, self).__init__() 74 | self.t = t 75 | self.odefunc = odefunc 76 | 77 | def set_x0(self, x0): 78 | self.odefunc.x0 = x0.clone().detach() 79 | 80 | 81 | 82 | def forward(self, x, select_dataset): 83 | def wrapped_odefunc(t, x): 84 | return self.odefunc(t, x, select_dataset) 85 | t = self.t.type_as(x) 86 | z = odeint(wrapped_odefunc, x, t, method='euler')[1] 87 | # z = odeint(self.odefunc, x, t, method='euler')[1] 88 | return z 89 | 90 | 91 | # Define the ODEGCN model. 92 | class ODEG(nn.Module): 93 | def __init__(self, feature_dim, temporal_dim, adj, time, dataset_use, dataset_test, mode, device): 94 | super(ODEG, self).__init__() 95 | self.odeblock = ODEblock(ODEFunc(feature_dim, temporal_dim, adj, dataset_use, dataset_test, mode, device), t=torch.tensor([0, time])) 96 | 97 | def forward(self, x, select_dataset=None): 98 | self.odeblock.set_x0(x) 99 | z = self.odeblock(x, select_dataset) 100 | return F.relu(z) 101 | -------------------------------------------------------------------------------- /model/STSGCN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import configparser 4 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 5 | import torch 6 | import pandas as pd 7 | 8 | def parse_args(DATASET): 9 | # get configuration 10 | config_file = '../conf/STSGCN/{}.conf'.format(DATASET) 11 | config = configparser.ConfigParser() 12 | config.read(config_file) 13 | parser = argparse.ArgumentParser() 14 | 15 | filter_list_str = config.get('model', 'filter_list') 16 | filter_list = eval(filter_list_str) 17 | 18 | # data 19 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 20 | parser.add_argument('--input_window', type=int, default=config['data']['input_window']) 21 | parser.add_argument('--output_window', type=int, default=config['data']['output_window']) 22 | 23 | # model 24 | parser.add_argument('--filter_list', type=list, default=config['model']['filter_list']) 25 | parser.add_argument('--rho', type=int, default=config['model']['rho']) 26 | parser.add_argument('--feature_dim', type=int, default=config['model']['feature_dim']) 27 | parser.add_argument('--module_type', type=str, default=config['model']['module_type']) 28 | parser.add_argument('--activation', type=str, default=config['model']['activation']) 29 | parser.add_argument('--temporal_emb', type=eval, default=config['model']['temporal_emb']) 30 | parser.add_argument('--spatial_emb', type=eval, default=config['model']['spatial_emb']) 31 | parser.add_argument('--use_mask', type=eval, default=config['model']['use_mask']) 32 | parser.add_argument('--steps', type=int, default=config['model']['steps']) 33 | parser.add_argument('--first_layer_embedding_size', type=int, default=config['model']['first_layer_embedding_size']) 34 | 35 | args, _ = parser.parse_known_args() 36 | args.filter_list = filter_list 37 | args.adj_mx = None 38 | return args -------------------------------------------------------------------------------- /model/ST_WA/ST_WA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .attention import TemporalAttention, SpatialAttention 4 | # from util import reparameterize 5 | 6 | def reparameterize(mu, logvar): 7 | std = torch.exp(0.5 * logvar) 8 | eps = torch.randn_like(std) 9 | return mu + eps * std 10 | 11 | class STWA(nn.Module): 12 | def __init__(self, device, num_nodes, input_dim, output_dim, channels, dynamic, lag, horizon, supports, 13 | memory_size, A_dict, dataset_use, dataset_test, mode): 14 | super(STWA, self).__init__() 15 | self.supports = supports 16 | self.num_nodes = num_nodes 17 | self.output_dim = output_dim 18 | self.channels = channels 19 | self.dynamic = dynamic 20 | self.horizon = horizon 21 | self.start_fc = nn.Linear(in_features=input_dim, out_features=self.channels) 22 | self.memory_size = memory_size 23 | 24 | if input_dim != 1: 25 | self.eval_dimin = nn.Linear(in_features=input_dim, out_features=1) 26 | 27 | self.layers = nn.ModuleList( 28 | [ 29 | Layer(device=device, input_dim=channels, dynamic=dynamic, num_nodes=num_nodes, cuts=12, 30 | cut_size=6, no_proxies=2, memory_size=memory_size, A_dict=A_dict, dataset_use=dataset_use, dataset_test=dataset_test, mode=mode), 31 | Layer(device=device, input_dim=channels, dynamic=dynamic, num_nodes=num_nodes, cuts=3, 32 | cut_size=4, no_proxies=2, memory_size=memory_size, A_dict=A_dict, dataset_use=dataset_use, dataset_test=dataset_test, mode=mode), 33 | Layer(device=device, input_dim=channels, dynamic=dynamic, num_nodes=num_nodes, cuts=1, 34 | cut_size=3, no_proxies=2, memory_size=memory_size, A_dict=A_dict, dataset_use=dataset_use, dataset_test=dataset_test, mode=mode), 35 | ]) 36 | 37 | self.skip_layers = nn.ModuleList([ 38 | nn.Linear(in_features=12 * channels, out_features=256), 39 | nn.Linear(in_features=3 * channels, out_features=256), 40 | nn.Linear(in_features=1 *channels, out_features=256), 41 | ]) 42 | 43 | self.projections = nn.Sequential(*[ 44 | nn.Linear(256, 512), 45 | nn.ReLU(), 46 | nn.Linear(512, horizon * self.output_dim)]) 47 | 48 | if self.dynamic: 49 | self.mu_estimator = nn.Sequential(*[ 50 | nn.Linear(lag, 32), 51 | nn.Tanh(), 52 | nn.Linear(32, 32), 53 | nn.Tanh(), 54 | nn.Linear(32, memory_size) 55 | ]) 56 | 57 | self.logvar_estimator = nn.Sequential(*[ 58 | nn.Linear(lag, 32), 59 | nn.Tanh(), 60 | nn.Linear(32, 32), 61 | nn.Tanh(), 62 | nn.Linear(32, memory_size) 63 | ]) 64 | 65 | def forward(self, x, select_dataset): 66 | if self.dynamic: 67 | if x.shape[-1] != 1: 68 | x_dm = self.eval_dimin(x) 69 | else: 70 | x_dm = x 71 | mu = self.mu_estimator(x_dm.transpose(3, 1).squeeze(1)) 72 | logvar = self.logvar_estimator(x_dm.transpose(3, 1).squeeze(1)) 73 | z_data = reparameterize(mu, logvar) 74 | else: 75 | z_data = 0 76 | 77 | 78 | x = self.start_fc(x) 79 | batch_size = x.size(0) 80 | num_nodes = x.shape[2] 81 | 82 | skip = 0 83 | for layer, skip_layer in zip(self.layers, self.skip_layers): 84 | x = layer(x, z_data, select_dataset) 85 | skip_inp = x.transpose(2, 1).reshape(batch_size, num_nodes, -1) 86 | skip = skip + skip_layer(skip_inp) 87 | 88 | x = torch.relu(skip) 89 | out = self.projections(x) 90 | 91 | if self.output_dim == 1: 92 | out = out.transpose(2, 1).unsqueeze(-1) 93 | else: 94 | out = out.unsqueeze(-1).reshape(batch_size, num_nodes, self.horizon, -1).transpose(2, 1) 95 | 96 | # print(out.shape) 97 | 98 | return out 99 | 100 | 101 | class Layer(nn.Module): 102 | def __init__(self, device, input_dim, num_nodes, cuts, cut_size, dynamic, memory_size, no_proxies, A_dict, dataset_use, dataset_test, mode): 103 | super(Layer, self).__init__() 104 | self.device = device 105 | self.input_dim = input_dim 106 | self.num_nodes = num_nodes 107 | self.dynamic = dynamic 108 | self.cuts = cuts 109 | self.cut_size = cut_size 110 | self.no_proxies = no_proxies 111 | self.mode = mode 112 | 113 | self.dataset2index = {} 114 | if mode == 'pretrain': 115 | self.proxies_pretrain = nn.ParameterList() 116 | self.mu_pretrain = nn.ParameterList() 117 | self.logvar_pretrain = nn.ParameterList() 118 | for i, data_graph in enumerate(dataset_use): 119 | self.dataset2index[data_graph] = i 120 | n_dataset = A_dict[data_graph].shape[0] 121 | self.proxies_pretrain.append(nn.Parameter(torch.randn(1, cuts * no_proxies, n_dataset, input_dim).to(device), 122 | requires_grad=True)) 123 | if self.dynamic: 124 | self.mu_pretrain.append(nn.Parameter(torch.randn(n_dataset, memory_size).to(device), requires_grad=True).to( 125 | device)) 126 | self.logvar_pretrain.append(nn.Parameter(torch.randn(n_dataset, memory_size).to(device), requires_grad=True).to( 127 | device)) 128 | else: 129 | self.proxies_eval = nn.ParameterList() 130 | self.mu_eval = nn.ParameterList() 131 | self.logvar_eval = nn.ParameterList() 132 | for i, data_graph in enumerate([dataset_test]): 133 | self.dataset2index[data_graph] = i 134 | n_dataset = A_dict[data_graph].shape[0] 135 | self.proxies_eval.append(nn.Parameter(torch.randn(1, cuts * no_proxies, n_dataset, input_dim).to(device), 136 | requires_grad=True).to(device)) 137 | 138 | if self.dynamic: 139 | self.mu_eval.append(nn.Parameter(torch.randn(n_dataset, memory_size).to(device), requires_grad=True).to( 140 | device)) 141 | self.logvar_eval.append(nn.Parameter(torch.randn(n_dataset, memory_size).to(device), requires_grad=True).to( 142 | device)) 143 | 144 | # self.proxies = nn.Parameter(torch.randn(1, cuts * no_proxies, self.num_nodes, input_dim).to(device), 145 | # requires_grad=True).to(device) 146 | 147 | self.temporal_att = TemporalAttention(input_dim, num_nodes=num_nodes, cut_size=cut_size) 148 | self.spatial_att = SpatialAttention(input_dim, num_nodes=num_nodes) 149 | 150 | # if self.dynamic: 151 | # self.mu = nn.Parameter(torch.randn(num_nodes, memory_size).to(device), requires_grad=True).to(device) 152 | # self.logvar = nn.Parameter(torch.randn(num_nodes, memory_size).to(device), requires_grad=True).to(device) 153 | # 154 | self.temporal_parameter_generators = nn.ModuleList([ 155 | ParameterGenerator(memory_size=memory_size, input_dim=input_dim, output_dim=input_dim, 156 | num_nodes=num_nodes, dynamic=dynamic) for _ in range(2) 157 | ]) 158 | 159 | self.spatial_parameter_generators = nn.ModuleList([ 160 | ParameterGenerator(memory_size=memory_size, input_dim=input_dim, output_dim=input_dim, 161 | num_nodes=num_nodes, dynamic=dynamic) for _ in range(2) 162 | ]) 163 | 164 | self.aggregator = nn.Sequential(*[ 165 | nn.Linear(input_dim, input_dim), 166 | nn.ReLU(), 167 | nn.Linear(input_dim, input_dim), 168 | nn.Sigmoid() 169 | ]) 170 | 171 | def forward(self, x, z_data, select_dataset): 172 | # x shape: B T N C 173 | batch_size = x.size(0) 174 | 175 | if self.dynamic: 176 | if self.mode == 'pretrain': 177 | z_sample = reparameterize(self.mu_pretrain[self.dataset2index[select_dataset]], self.logvar_pretrain[self.dataset2index[select_dataset]]) 178 | else: 179 | z_sample = reparameterize(self.mu_eval[self.dataset2index[select_dataset]], self.logvar_eval[self.dataset2index[select_dataset]]) 180 | # z_sample = reparameterize(self.mu, self.logvar) 181 | z_data = z_data + z_sample 182 | 183 | temporal_parameters = [layer(x, z_data) for layer in self.temporal_parameter_generators] 184 | spatial_parameters = [layer(x, z_data) for layer in self.spatial_parameter_generators] 185 | 186 | data_concat = [] 187 | out = 0 188 | for i in range(self.cuts): 189 | # shape is (B, cut_size, N, C) 190 | t = x[:, i * self.cut_size:(i + 1) * self.cut_size, :, :] 191 | 192 | if self.mode == 'pretrain': 193 | proxies = self.proxies_pretrain[self.dataset2index[select_dataset]][:, i * self.no_proxies: (i + 1) * self.no_proxies] 194 | else: 195 | proxies = self.proxies_eval[self.dataset2index[select_dataset]][:, i * self.no_proxies: (i + 1) * self.no_proxies] 196 | # proxies = self.proxies[:, i * self.no_proxies: (i + 1) * self.no_proxies] 197 | proxies = proxies.repeat(batch_size, 1, 1, 1) + out 198 | t = torch.cat([proxies, t], dim=1) 199 | 200 | out = self.temporal_att(t[:, :self.no_proxies, :, :], t, t, temporal_parameters) 201 | out = self.spatial_att(out, spatial_parameters) 202 | out = (self.aggregator(out) * out).sum(1, keepdim=True) 203 | data_concat.append(out) 204 | 205 | return torch.cat(data_concat, dim=1) 206 | 207 | class ParameterGenerator(nn.Module): 208 | def __init__(self, memory_size, input_dim, output_dim, num_nodes, dynamic): 209 | super(ParameterGenerator, self).__init__() 210 | self.input_dim = input_dim 211 | self.output_dim = output_dim 212 | self.num_nodes = num_nodes 213 | self.dynamic = dynamic 214 | 215 | if self.dynamic: 216 | print('Using DYNAMIC') 217 | self.weight_generator = nn.Sequential(*[ 218 | nn.Linear(memory_size, 32), 219 | nn.ReLU(), 220 | nn.Linear(32, 5), 221 | nn.ReLU(), 222 | nn.Linear(5, input_dim * output_dim) 223 | ]) 224 | self.bias_generator = nn.Sequential(*[ 225 | nn.Linear(memory_size, 32), 226 | nn.ReLU(), 227 | nn.Linear(32, 5), 228 | nn.ReLU(), 229 | nn.Linear(5, output_dim) 230 | ]) 231 | else: 232 | print('Using FC') 233 | self.weights = nn.Parameter(torch.rand(input_dim, output_dim), requires_grad=True) 234 | self.biases = nn.Parameter(torch.rand(input_dim), requires_grad=True) 235 | 236 | def forward(self, x, memory=None): 237 | if self.dynamic: 238 | # weights = self.weight_generator(memory).view(x.shape[0], self.num_nodes, self.input_dim, self.output_dim) 239 | # biases = self.bias_generator(memory).view(x.shape[0], self.num_nodes, self.output_dim) 240 | weights = self.weight_generator(memory).view(x.shape[0], -1, self.input_dim, self.output_dim) 241 | biases = self.bias_generator(memory).view(x.shape[0], -1, self.output_dim) 242 | else: 243 | weights = self.weights 244 | biases = self.biases 245 | return weights, biases 246 | -------------------------------------------------------------------------------- /model/ST_WA/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import configparser 4 | import torch 5 | from scipy.sparse.linalg import eigs 6 | from lib.predifineGraph import load_pickle, weight_matrix, get_adjacency_matrix 7 | import pandas as pd 8 | 9 | # def get_adjacency_matrix(distance_df_filename, num_of_vertices, 10 | # type_='connectivity', id_filename=None): 11 | # ''' 12 | # Parameters 13 | # ---------- 14 | # distance_df_filename: str, path of the csv file contains edges information 15 | # num_of_vertices: int, the number of vertices 16 | # type_: str, {connectivity, distance} 17 | # Returns 18 | # ---------- 19 | # A: np.ndarray, adjacency matrix 20 | # ''' 21 | # import csv 22 | # 23 | # A = np.zeros((int(num_of_vertices), int(num_of_vertices)), 24 | # dtype=np.float32) 25 | # 26 | # if id_filename != 'None': 27 | # with open(id_filename, 'r') as f: 28 | # id_dict = {int(i): idx 29 | # for idx, i in enumerate(f.read().strip().split('\n'))} 30 | # with open(distance_df_filename, 'r') as f: 31 | # f.readline() 32 | # reader = csv.reader(f) 33 | # for row in reader: 34 | # if len(row) != 3: 35 | # continue 36 | # i, j, distance = int(row[0]), int(row[1]), float(row[2]) 37 | # A[id_dict[i], id_dict[j]] = 1 38 | # A[id_dict[j], id_dict[i]] = 1 39 | # return A 40 | # 41 | # # Fills cells in the matrix with distances. 42 | # with open(distance_df_filename, 'r') as f: 43 | # f.readline() 44 | # reader = csv.reader(f) 45 | # for row in reader: 46 | # if len(row) != 3: 47 | # continue 48 | # i, j, distance = int(row[0]), int(row[1]), float(row[2]) 49 | # if type_ == 'connectivity': 50 | # A[i, j] = 1 51 | # A[j, i] = 1 52 | # elif type_ == 'distance': 53 | # A[i, j] = 1 / distance 54 | # A[j, i] = 1 / distance 55 | # else: 56 | # raise ValueError("type_ error, must be " 57 | # "connectivity or distance!") 58 | # return A 59 | 60 | 61 | 62 | def scaled_Laplacian(W): 63 | ''' 64 | compute \tilde{L} 65 | 66 | Parameters 67 | ---------- 68 | W: np.ndarray, shape is (N, N), N is the num of vertices 69 | 70 | Returns 71 | ---------- 72 | scaled_Laplacian: np.ndarray, shape (N, N) 73 | 74 | ''' 75 | 76 | assert W.shape[0] == W.shape[1] 77 | 78 | D = np.diag(np.sum(W, axis=1)) 79 | 80 | L = D - W 81 | 82 | lambda_max = eigs(L, k=1, which='LR')[0].real 83 | 84 | return (2 * L) / lambda_max - np.identity(W.shape[0]) 85 | 86 | def parse_args(DATASET): 87 | # get configuration 88 | config_file = '../conf/ST-WA/{}.conf'.format(DATASET) 89 | config = configparser.ConfigParser() 90 | config.read(config_file) 91 | 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('--device', default=config['general']['device'], type=str) 94 | parser.add_argument('--data', default=DATASET, help='data path', type=str, ) 95 | # parser.add_argument('--adj_filename', type=str, default=config['data']['adj_filename']) 96 | parser.add_argument('--id_filename', type=str, default=config['data']['id_filename']) 97 | parser.add_argument('--val_ratio', type=float, default=config['data']['val_ratio']) 98 | parser.add_argument('--test_ratio', type=float, default=config['data']['test_ratio']) 99 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 100 | parser.add_argument('--lag', type=int, default=config['data']['lag']) 101 | parser.add_argument('--horizon', type=int, default=config['data']['horizon']) 102 | 103 | parser.add_argument('--in_dim', type=int, default=config['model']['in_dim']) 104 | parser.add_argument('--out_dim', type=int, default=config['model']['out_dim']) 105 | parser.add_argument('--channels', type=int, default=config['model']['channels']) 106 | parser.add_argument('--dynamic', type=str, default=config['model']['dynamic']) 107 | parser.add_argument('--memory_size', type=int, default=config['model']['memory_size']) 108 | 109 | parser.add_argument('--column_wise', type=bool, default=False) 110 | 111 | args, _ = parser.parse_known_args() 112 | 113 | args.supports = None 114 | return args -------------------------------------------------------------------------------- /model/ST_WA/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class TemporalAttention(nn.Module): 6 | def __init__(self, in_dim, num_nodes=None, cut_size=0): 7 | super(TemporalAttention, self).__init__() 8 | self.K = 8 9 | 10 | if in_dim % self.K != 0: 11 | raise Exception('Hidden size is not divisible by the number of attention heads') 12 | 13 | self.head_size = int(in_dim // self.K) 14 | self.key_proj = LinearCustom() 15 | self.value_proj = LinearCustom() 16 | 17 | self.projection1 = nn.Linear(in_dim,in_dim) 18 | self.projection2 = nn.Linear(in_dim,in_dim) 19 | 20 | def forward(self, query, key, value, parameters): 21 | batch_size = query.shape[0] 22 | 23 | # [batch_size, num_step, N, K * head_size] 24 | key = self.key_proj(key, parameters[0]) 25 | value = self.value_proj(value, parameters[1]) 26 | 27 | # [K * batch_size, num_step, N, head_size] 28 | query = torch.cat(torch.split(query, self.head_size, dim=-1), dim=0) 29 | key = torch.cat(torch.split(key, self.head_size, dim=-1), dim=0) 30 | value = torch.cat(torch.split(value, self.head_size, dim=-1), dim=0) 31 | 32 | # query: [K * batch_size, N, 1, head_size] 33 | # key: [K * batch_size, N, head_size, num_step] 34 | # value: [K * batch_size, N, num_step, head_size] 35 | query = query.permute((0, 2, 1, 3)) 36 | key = key.permute((0, 2, 3, 1)) 37 | value = value.permute((0, 2, 1, 3)) 38 | 39 | attention = torch.matmul(query, key) # [K * batch_size, N, num_step, num_step] 40 | attention /= (self.head_size ** 0.5) 41 | 42 | # normalize the attention scores 43 | # attention = self.mask * attention 44 | attention = F.softmax(attention, dim=-1) 45 | 46 | x = torch.matmul(attention, value) # [batch_size * head_size, num_step, N, K] 47 | x = x.permute((0, 2, 1, 3)) 48 | x = torch.cat(torch.split(x, batch_size, dim=0), dim=-1) 49 | 50 | # projection 51 | x = self.projection1(x) 52 | x = F.tanh(x) 53 | x = self.projection2(x) 54 | return x 55 | 56 | 57 | class SpatialAttention(nn.Module): 58 | def __init__(self, in_dim, support=None, num_nodes=None): 59 | super(SpatialAttention, self).__init__() 60 | self.support = support 61 | self.K = 8 62 | 63 | if in_dim % self.K != 0: 64 | raise Exception('Hidden size is not divisible by the number of attention heads') 65 | 66 | self.head_size = int(in_dim // self.K) 67 | self.linear = LinearCustom() 68 | self.projection1 = nn.Linear(in_dim, in_dim) 69 | self.projection2 = nn.Linear(in_dim, in_dim) 70 | 71 | def forward(self, x, parameters): 72 | batch_size = x.shape[0] 73 | # [batch_size, 1, N, K * head_size] 74 | # query = self.linear(x, parameters[2]) 75 | key = self.linear(x, parameters[0]) 76 | value = self.linear(x, parameters[1]) 77 | 78 | # [K * batch_size, num_step, N, head_size] 79 | query = torch.cat(torch.split(x, self.head_size, dim=-1), dim=0) 80 | key = torch.cat(torch.split(key, self.head_size, dim=-1), dim=0) 81 | value = torch.cat(torch.split(value, self.head_size, dim=-1), dim=0) 82 | 83 | attention = torch.matmul(query, key.transpose(2, 3)) # [K * batch_size, N, num_step, num_step] 84 | attention /= (self.head_size ** 0.5) 85 | 86 | attention = F.softmax(attention, dim=-1) 87 | x = torch.matmul(attention, value) # [batch_size * head_size, num_step, N, K] 88 | x = torch.cat(torch.split(x, batch_size, dim=0), dim=-1) 89 | 90 | # projection 91 | x = self.projection1(x) 92 | x = F.relu(x) 93 | x = self.projection2(x) 94 | return x 95 | 96 | 97 | class LinearCustom(nn.Module): 98 | 99 | def __init__(self): 100 | super(LinearCustom, self).__init__() 101 | 102 | def forward(self, inputs, parameters): 103 | weights, biases = parameters[0], parameters[1] 104 | if len(weights.shape) > 3: 105 | return torch.matmul(inputs.unsqueeze(-2), weights.unsqueeze(1).repeat(1, inputs.shape[1], 1, 1, 1)).squeeze( 106 | -2) + biases.unsqueeze(1).repeat(1, inputs.shape[1], 1, 1) 107 | return torch.matmul(inputs, weights) + biases 108 | -------------------------------------------------------------------------------- /model/TGCN/TGCN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | import torch.nn as nn 5 | from logging import getLogger 6 | 7 | 8 | def calculate_normalized_laplacian(adj): 9 | """ 10 | A = A + I 11 | L = D^-1/2 A D^-1/2 12 | 13 | Args: 14 | adj: adj matrix 15 | 16 | Returns: 17 | np.ndarray: L 18 | """ 19 | adj = sp.coo_matrix(adj + sp.eye(adj.shape[0])) 20 | d = np.array(adj.sum(1)) 21 | d_inv_sqrt = np.power(d, -0.5).flatten() 22 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 23 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 24 | normalized_laplacian = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 25 | return normalized_laplacian 26 | 27 | 28 | class TGCNCell(nn.Module): 29 | def __init__(self, num_units, adj_mx, num_nodes, device, input_dim=1): 30 | # ----------------------初始化参数---------------------------# 31 | super().__init__() 32 | self.num_units = num_units 33 | self.num_nodes = num_nodes 34 | self.input_dim = input_dim 35 | self._device = device 36 | self.act = torch.tanh 37 | 38 | # 这里提前构建好拉普拉斯 39 | support = calculate_normalized_laplacian(adj_mx) 40 | self.normalized_adj = self._build_sparse_matrix(support, self._device) 41 | self.init_params() 42 | 43 | def init_params(self, bias_start=0.0): 44 | input_size = self.input_dim + self.num_units 45 | weight_0 = torch.nn.Parameter(torch.empty((input_size, 2 * self.num_units), device=self._device)) 46 | bias_0 = torch.nn.Parameter(torch.empty(2 * self.num_units, device=self._device)) 47 | weight_1 = torch.nn.Parameter(torch.empty((input_size, self.num_units), device=self._device)) 48 | bias_1 = torch.nn.Parameter(torch.empty(self.num_units, device=self._device)) 49 | 50 | torch.nn.init.xavier_normal_(weight_0) 51 | torch.nn.init.xavier_normal_(weight_1) 52 | torch.nn.init.constant_(bias_0, bias_start) 53 | torch.nn.init.constant_(bias_1, bias_start) 54 | 55 | self.register_parameter(name='weights_0', param=weight_0) 56 | self.register_parameter(name='weights_1', param=weight_1) 57 | self.register_parameter(name='bias_0', param=bias_0) 58 | self.register_parameter(name='bias_1', param=bias_1) 59 | 60 | self.weigts = {weight_0.shape: weight_0, weight_1.shape: weight_1} 61 | self.biases = {bias_0.shape: bias_0, bias_1.shape: bias_1} 62 | 63 | @staticmethod 64 | def _build_sparse_matrix(lap, device): 65 | lap = lap.tocoo() 66 | indices = np.column_stack((lap.row, lap.col)) 67 | # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L) 68 | indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))] 69 | lap = torch.sparse_coo_tensor(indices.T, lap.data, lap.shape, device=device) 70 | return lap 71 | 72 | def forward(self, inputs, state): 73 | """ 74 | Gated recurrent unit (GRU) with Graph Convolution. 75 | 76 | Args: 77 | inputs: shape (batch, self.num_nodes * self.dim) 78 | state: shape (batch, self.num_nodes * self.gru_units) 79 | 80 | Returns: 81 | torch.tensor: shape (B, num_nodes * gru_units) 82 | """ 83 | output_size = 2 * self.num_units 84 | value = torch.sigmoid( 85 | self._gc(inputs, state, output_size, bias_start=1.0)) # (batch_size, self.num_nodes, output_size) 86 | r, u = torch.split(tensor=value, split_size_or_sections=self.num_units, dim=-1) 87 | r = torch.reshape(r, (-1, self.num_nodes * self.num_units)) # (batch_size, self.num_nodes * self.gru_units) 88 | u = torch.reshape(u, (-1, self.num_nodes * self.num_units)) 89 | 90 | c = self.act(self._gc(inputs, r * state, self.num_units)) 91 | c = c.reshape(shape=(-1, self.num_nodes * self.num_units)) 92 | new_state = u * state + (1.0 - u) * c 93 | return new_state 94 | 95 | def _gc(self, inputs, state, output_size, bias_start=0.0): 96 | """ 97 | GCN 98 | 99 | Args: 100 | inputs: (batch, self.num_nodes * self.dim) 101 | state: (batch, self.num_nodes * self.gru_units) 102 | output_size: 103 | bias_start: 104 | 105 | Returns: 106 | torch.tensor: (B, num_nodes , output_size) 107 | """ 108 | batch_size = inputs.shape[0] 109 | inputs = torch.reshape(inputs, (batch_size, self.num_nodes, -1)) # (batch, self.num_nodes, self.dim) 110 | state = torch.reshape(state, (batch_size, self.num_nodes, -1)) # (batch, self.num_nodes, self.gru_units) 111 | inputs_and_state = torch.cat([inputs, state], dim=2) 112 | input_size = inputs_and_state.shape[2] 113 | 114 | x = inputs_and_state 115 | x0 = x.permute(1, 2, 0) # (num_nodes, dim+gru_units, batch) 116 | x0 = x0.reshape(shape=(self.num_nodes, -1)) 117 | 118 | x1 = torch.sparse.mm(self.normalized_adj.float(), x0.float()) # A * X 119 | 120 | x1 = x1.reshape(shape=(self.num_nodes, input_size, batch_size)) 121 | x1 = x1.permute(2, 0, 1) # (batch_size, self.num_nodes, input_size) 122 | x1 = x1.reshape(shape=(-1, input_size)) # (batch_size * self.num_nodes, input_size) 123 | 124 | weights = self.weigts[(input_size, output_size)] 125 | x1 = torch.matmul(x1, weights) # (batch_size * self.num_nodes, output_size) 126 | 127 | biases = self.biases[(output_size,)] 128 | x1 += biases 129 | 130 | x1 = x1.reshape(shape=(batch_size, self.num_nodes, output_size)) 131 | return x1 132 | 133 | 134 | class TGCN(nn.Module): 135 | def __init__(self, args, A_dict, dataset_test, device, dim_in): 136 | super(TGCN, self).__init__() 137 | self.adj_mx = A_dict[dataset_test].cpu().numpy() 138 | self.num_nodes = args.num_nodes 139 | self.input_dim = dim_in 140 | self.output_dim = args.output_dim 141 | self.gru_units = args.rnn_units 142 | self.lam = args.lam 143 | 144 | self.input_window = args.input_window 145 | self.output_window = args.output_window 146 | self.device = device 147 | 148 | 149 | # -------------------构造模型----------------------------- 150 | self.tgcn_model = TGCNCell(self.gru_units, self.adj_mx, self.num_nodes, self.device, self.input_dim) 151 | self.output_model = nn.Linear(self.gru_units, self.output_window * self.output_dim) 152 | 153 | def forward(self, source, select_dataset): 154 | """ 155 | Args: 156 | batch: a batch of input, 157 | batch['X']: shape (batch_size, input_window, num_nodes, input_dim) \n 158 | batch['y']: shape (batch_size, output_window, num_nodes, output_dim) \n 159 | 160 | Returns: 161 | torch.tensor: (batch_size, self.output_window, self.num_nodes, self.output_dim) 162 | """ 163 | inputs = source 164 | # labels = batch['y'] 165 | # print(inputs.shape) 166 | 167 | batch_size, input_window, num_nodes, input_dim = inputs.shape 168 | inputs = inputs.permute(1, 0, 2, 3).contiguous() # (input_window, batch_size, num_nodes, input_dim) 169 | inputs = inputs.view(self.input_window, batch_size, num_nodes * input_dim).to(self.device) 170 | 171 | state = torch.zeros(batch_size, self.num_nodes * self.gru_units).to(self.device) 172 | for t in range(input_window): 173 | state = self.tgcn_model(inputs[t], state) 174 | 175 | state = state.view(batch_size, self.num_nodes, self.gru_units) # (batch_size, self.num_nodes, self.gru_units) 176 | output = self.output_model(state) # (batch_size, self.num_nodes, self.output_window * self.output_dim) 177 | output = output.view(batch_size, self.num_nodes, self.output_window, self.output_dim) 178 | output = output.permute(0, 2, 1, 3) 179 | return output -------------------------------------------------------------------------------- /model/TGCN/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import configparser 4 | from lib.predifineGraph import get_adjacency_matrix, load_pickle, weight_matrix 5 | import torch 6 | import pandas as pd 7 | 8 | def parse_args(DATASET): 9 | # get configuration 10 | config_file = '../conf/TGCN/{}.conf'.format(DATASET) 11 | config = configparser.ConfigParser() 12 | config.read(config_file) 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | # data 17 | parser.add_argument('--num_nodes', type=int, default=config['data']['num_nodes']) 18 | parser.add_argument('--input_window', type=int, default=config['data']['input_window']) 19 | parser.add_argument('--output_window', type=int, default=config['data']['output_window']) 20 | # model 21 | parser.add_argument('--rnn_units', type=int, default=config['model']['rnn_units']) 22 | parser.add_argument('--lam', type=float, default=config['model']['lam']) 23 | parser.add_argument('--output_dim', type=int, default=config['model']['output_dim']) 24 | 25 | args, _ = parser.parse_known_args() 26 | args.adj_mx = None 27 | return args -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastdtw 2 | h5py 3 | numpy 4 | pandas 5 | scipy 6 | torchdiffeq 7 | tqdm 8 | tslearn 9 | --------------------------------------------------------------------------------