├── LICENSE ├── README.md ├── config ├── astgcn.json ├── dcrnn.json ├── graphwavenet.json ├── gru.json ├── stgcn.json └── traversenet.json ├── config_electricity ├── astgcn.json ├── dcrnn.json ├── graphwavenet.json ├── gru.json ├── stgcn.json └── traversenet.json ├── config_solar ├── astgcn.json ├── dcrnn.json ├── graphwavenet.json ├── gru.json ├── stgcn.json └── traversenet.json ├── dataset ├── __init__.py └── data.py ├── layers ├── __init__.py ├── dcrnn_cell.py ├── gat_layer.py ├── layernorm.py ├── mlp_layer.py └── smt.py ├── main.py ├── module ├── __init__.py ├── astgcn_block.py ├── dcrnn.py ├── stgcn_block.py └── traversebody.py ├── nets ├── __init__.py ├── astgcn_net.py ├── dcrnn_net.py ├── graphwavenet.py ├── stgcn_net.py └── traverse_net.py ├── proc_data.py ├── proc_new_data.py ├── requirements.txt ├── train_electricity.sh ├── train_solar.sh ├── trainer ├── __init__.py ├── ctrainer.py ├── rtrainer.py └── tg_trainer.py └── utils ├── __init__.py ├── metrics.py └── process.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | ## Preparation 4 | 5 | 1. Download data from https://github.com/Davidham3/STSGCN 6 | 2. Make the data folder and move the downloaded dataset into the data folder 7 | 3. 
Pre-process data: 8 | ``` 9 | python proc_data.py 10 | ``` 11 | 12 | ## Training 13 | For our model 14 | ``` 15 | python main.py --config ./config/traversenet.json 16 | ``` 17 | For baseline models 18 | ``` 19 | python main.py --config ./config/astgcn.json 20 | python main.py --config ./config/dcrnn.json 21 | python main.py --config ./config/graphwavenet.json 22 | python main.py --config ./config/gru.json 23 | python main.py --config ./config/stgcn.json 24 | ``` 25 | -------------------------------------------------------------------------------- /config/astgcn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 1, 10 | "runs": 1, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "PEMS-D8", 17 | "out_dir": "save_astgcn" 18 | }, 19 | 20 | "net_params": { 21 | "model": "astgcn", 22 | "in_dim": 3, 23 | "nb_block": 2, 24 | "K" : 3, 25 | "nb_chev_filter": 64, 26 | "nb_time_filter": 64, 27 | "time_strides": 1, 28 | "num_nodes": 170, 29 | "seq_in_len": 12, 30 | "seq_out_len": 3 31 | } 32 | } 33 | 34 | -------------------------------------------------------------------------------- /config/dcrnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 1, 11 | "runs": 1, 12 | "batch_size": 64, 13 | "lr": 0.001, 14 | "weight_decay": 0, 15 | "epsilon": 1.0e-3, 16 | "clip": 5, 17 | "print_every": 50, 18 | "dataset": "PEMS-D8", 19 | "out_dir": "save_dcrnn" 20 | }, 21 | 22 | "net_params": { 23 | "model": "dcrnn", 24 | "in_dim": 3, 25 | "dim": 64, 26 | "out_dim": 1, 27 | "num_layers": 2, 28 | "num_nodes": 170, 29 | "dropout": 0, 30 | "seq_in_len": 12, 31 | "seq_out_len": 12, 32 | "filter_type": "dual_random_walk", 33 | 
"use_curriculum_learning": true, 34 | "cl_decay_steps": 2000, 35 | "max_diffusion_step": 2 36 | } 37 | } 38 | 39 | -------------------------------------------------------------------------------- /config/graphwavenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 1, 11 | "runs": 1, 12 | "batch_size": 64, 13 | "lr": 0.001, 14 | "weight_decay": 0.0001, 15 | "clip": 5, 16 | "print_every": 50, 17 | "dataset": "PEMS-08", 18 | "out_dir": "save_graphwavenet" 19 | }, 20 | 21 | "net_params": { 22 | "model": "graphwavenet", 23 | "in_dim": 3, 24 | "num_layers": 2, 25 | "blocks": 4, 26 | "gcn_depth": 2, 27 | "gcn_bool": true, 28 | "addaptadj": false, 29 | "num_nodes": 170, 30 | "dropout": 0, 31 | "seq_in_len": 12, 32 | "seq_out_len": 12, 33 | "residual_channels": 32, 34 | "skip_channels": 256, 35 | "end_channels": 512, 36 | "dilation_channels": 32, 37 | "node_dim": 40, 38 | "kernel_size": 2 39 | } 40 | } 41 | 42 | -------------------------------------------------------------------------------- /config/gru.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 1, 11 | "runs": 1, 12 | "batch_size": 64, 13 | "lr": 0.001, 14 | "weight_decay": 0, 15 | "epsilon": 1.0e-3, 16 | "clip": 5, 17 | "print_every": 50, 18 | "dataset": "PEMS-D8", 19 | "out_dir": "save_gru" 20 | }, 21 | 22 | "net_params": { 23 | "model": "gru", 24 | "in_dim": 3, 25 | "dim": 64, 26 | "out_dim": 1, 27 | "num_layers": 2, 28 | "num_nodes": 170, 29 | "dropout": 0, 30 | "seq_in_len": 12, 31 | "seq_out_len": 12, 32 | "filter_type": "dual_random_walk", 33 | "use_curriculum_learning": true, 34 | "cl_decay_steps": 2000, 35 | "max_diffusion_step": 2 36 | } 37 | } 38 | 39 | 
-------------------------------------------------------------------------------- /config/stgcn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 1, 10 | "runs": 1, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "PEMS-D8", 17 | "out_dir": "save_stgcn" 18 | }, 19 | 20 | "net_params": { 21 | "model": "stgcn", 22 | "in_dim": 3, 23 | "num_nodes": 170, 24 | "seq_in_len": 12, 25 | "seq_out_len": 12 26 | } 27 | } 28 | 29 | -------------------------------------------------------------------------------- /config/traversenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 1, 10 | "runs": 1, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "PEMS-D8", 17 | "graph_path": "data/PEMS-D8-Gt.pkl", 18 | "out_dir": "save_tgnet" 19 | }, 20 | 21 | "net_params": { 22 | "model": "traversenet", 23 | "in_dim": 3, 24 | "dim": 64, 25 | "num_layers": 3, 26 | "heads": 1, 27 | "num_nodes": 170, 28 | "dropout": 0.0, 29 | "seq_in_len": 12, 30 | "seq_out_len": 12, 31 | "cl_decay_steps": 500 32 | } 33 | } -------------------------------------------------------------------------------- /config_electricity/astgcn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 50, 10 | "runs": 5, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "electricity", 17 | "out_dir": "save_astgcn" 18 | }, 19 | 20 | "net_params": { 21 | "model": "astgcn", 22 | "in_dim": 1, 23 | "nb_block": 2, 24 
| "K" : 3, 25 | "nb_chev_filter": 64, 26 | "nb_time_filter": 64, 27 | "time_strides": 1, 28 | "num_nodes": 321, 29 | "seq_in_len": 12, 30 | "seq_out_len": 3 31 | } 32 | } 33 | 34 | -------------------------------------------------------------------------------- /config_electricity/dcrnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 50, 11 | "runs": 5, 12 | "batch_size": 64, 13 | "lr": 0.01, 14 | "weight_decay": 0, 15 | "epsilon": 1.0e-3, 16 | "clip": 5, 17 | "print_every": 50, 18 | "dataset": "electricity", 19 | "out_dir": "save_dcrnn" 20 | }, 21 | 22 | "net_params": { 23 | "model": "dcrnn", 24 | "in_dim": 1, 25 | "dim": 64, 26 | "out_dim": 1, 27 | "num_layers": 2, 28 | "num_nodes": 321, 29 | "dropout": 0, 30 | "seq_in_len": 12, 31 | "seq_out_len": 12, 32 | "filter_type": "dual_random_walk", 33 | "use_curriculum_learning": true, 34 | "cl_decay_steps": 2000, 35 | "max_diffusion_step": 2 36 | } 37 | } 38 | 39 | -------------------------------------------------------------------------------- /config_electricity/graphwavenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 50, 11 | "runs": 5, 12 | "batch_size": 64, 13 | "lr": 0.001, 14 | "weight_decay": 0.0001, 15 | "clip": 5, 16 | "print_every": 50, 17 | "dataset": "electricity", 18 | "out_dir": "save_graphwavenet" 19 | }, 20 | 21 | "net_params": { 22 | "model": "graphwavenet", 23 | "in_dim": 1, 24 | "num_layers": 2, 25 | "blocks": 4, 26 | "gcn_depth": 2, 27 | "gcn_bool": true, 28 | "addaptadj": false, 29 | "num_nodes": 321, 30 | "dropout": 0.3, 31 | "seq_in_len": 12, 32 | "seq_out_len": 12, 33 | "residual_channels": 32, 34 | "skip_channels": 256, 35 | "end_channels": 512, 36 | "dilation_channels": 32, 37 | "node_dim": 40, 38 
| "kernel_size": 2 39 | } 40 | } 41 | 42 | -------------------------------------------------------------------------------- /config_electricity/gru.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 50, 11 | "runs": 5, 12 | "batch_size": 64, 13 | "lr": 0.01, 14 | "weight_decay": 0, 15 | "epsilon": 1.0e-3, 16 | "clip": 5, 17 | "print_every": 50, 18 | "dataset": "electricity", 19 | "out_dir": "save_gru" 20 | }, 21 | 22 | "net_params": { 23 | "model": "gru", 24 | "in_dim": 1, 25 | "dim": 64, 26 | "out_dim": 1, 27 | "num_layers": 2, 28 | "num_nodes": 321, 29 | "dropout": 0, 30 | "seq_in_len": 12, 31 | "seq_out_len": 12, 32 | "filter_type": "dual_random_walk", 33 | "use_curriculum_learning": true, 34 | "cl_decay_steps": 2000, 35 | "max_diffusion_step": 2 36 | } 37 | } 38 | 39 | -------------------------------------------------------------------------------- /config_electricity/stgcn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 50, 10 | "runs": 5, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "electricity", 17 | "out_dir": "save_stgcn" 18 | }, 19 | 20 | "net_params": { 21 | "model": "stgcn", 22 | "in_dim": 1, 23 | "num_nodes": 321, 24 | "seq_in_len": 12, 25 | "seq_out_len": 12 26 | } 27 | } 28 | 29 | -------------------------------------------------------------------------------- /config_electricity/traversenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 50, 10 | "runs": 5, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | 
"print_every": 50, 16 | "dataset": "electricity", 17 | "graph_path": "data/PEMS-D8-Gt.pkl", 18 | "out_dir": "save_tgnet" 19 | }, 20 | 21 | "net_params": { 22 | "model": "traversenet", 23 | "in_dim": 1, 24 | "dim": 64, 25 | "num_layers": 3, 26 | "heads": 1, 27 | "num_nodes": 321, 28 | "dropout": 0.0, 29 | "seq_in_len": 12, 30 | "seq_out_len": 12, 31 | "cl_decay_steps": 500 32 | } 33 | } -------------------------------------------------------------------------------- /config_solar/astgcn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 50, 10 | "runs": 5, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "solar", 17 | "out_dir": "save_astgcn" 18 | }, 19 | 20 | "net_params": { 21 | "model": "astgcn", 22 | "in_dim": 1, 23 | "nb_block": 2, 24 | "K" : 3, 25 | "nb_chev_filter": 64, 26 | "nb_time_filter": 64, 27 | "time_strides": 1, 28 | "num_nodes": 137, 29 | "seq_in_len": 12, 30 | "seq_out_len": 12 31 | } 32 | } 33 | 34 | -------------------------------------------------------------------------------- /config_solar/dcrnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 50, 11 | "runs": 5, 12 | "batch_size": 64, 13 | "lr": 0.001, 14 | "weight_decay": 0, 15 | "epsilon": 1.0e-3, 16 | "clip": 5, 17 | "print_every": 50, 18 | "dataset": "solar", 19 | "out_dir": "save_dcrnn" 20 | }, 21 | 22 | "net_params": { 23 | "model": "dcrnn", 24 | "in_dim": 1, 25 | "dim": 64, 26 | "out_dim": 1, 27 | "num_layers": 2, 28 | "num_nodes": 137, 29 | "dropout": 0, 30 | "seq_in_len": 12, 31 | "seq_out_len": 12, 32 | "filter_type": "dual_random_walk", 33 | "use_curriculum_learning": true, 34 | "cl_decay_steps": 2000, 35 | "max_diffusion_step": 2 
36 | } 37 | } 38 | 39 | -------------------------------------------------------------------------------- /config_solar/graphwavenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 50, 11 | "runs": 5, 12 | "batch_size": 64, 13 | "lr": 0.001, 14 | "weight_decay": 0.0001, 15 | "clip": 5, 16 | "print_every": 50, 17 | "dataset": "solar", 18 | "out_dir": "save_graphwavenet" 19 | }, 20 | 21 | "net_params": { 22 | "model": "graphwavenet", 23 | "in_dim": 1, 24 | "num_layers": 2, 25 | "blocks": 4, 26 | "gcn_depth": 2, 27 | "gcn_bool": true, 28 | "addaptadj": false, 29 | "num_nodes": 137, 30 | "dropout": 0, 31 | "seq_in_len": 12, 32 | "seq_out_len": 12, 33 | "residual_channels": 32, 34 | "skip_channels": 256, 35 | "end_channels": 512, 36 | "dilation_channels": 32, 37 | "node_dim": 40, 38 | "kernel_size": 2 39 | } 40 | } 41 | 42 | -------------------------------------------------------------------------------- /config_solar/gru.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | 8 | "params": { 9 | "seed": 9, 10 | "epochs": 50, 11 | "runs": 5, 12 | "batch_size": 64, 13 | "lr": 0.001, 14 | "weight_decay": 0, 15 | "epsilon": 1.0e-3, 16 | "clip": 5, 17 | "print_every": 50, 18 | "dataset": "solar", 19 | "out_dir": "save_gru" 20 | }, 21 | 22 | "net_params": { 23 | "model": "gru", 24 | "in_dim": 1, 25 | "dim": 64, 26 | "out_dim": 1, 27 | "num_layers": 2, 28 | "num_nodes": 137, 29 | "dropout": 0, 30 | "seq_in_len": 12, 31 | "seq_out_len": 12, 32 | "filter_type": "dual_random_walk", 33 | "use_curriculum_learning": true, 34 | "cl_decay_steps": 2000, 35 | "max_diffusion_step": 2 36 | } 37 | } 38 | 39 | -------------------------------------------------------------------------------- /config_solar/stgcn.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 50, 10 | "runs": 5, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "solar", 17 | "out_dir": "save_stgcn" 18 | }, 19 | 20 | "net_params": { 21 | "model": "stgcn", 22 | "in_dim": 1, 23 | "num_nodes": 137, 24 | "seq_in_len": 12, 25 | "seq_out_len": 12 26 | } 27 | } 28 | 29 | -------------------------------------------------------------------------------- /config_solar/traversenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": { 3 | "use": true, 4 | "id": 0 5 | }, 6 | 7 | "params": { 8 | "seed": 9, 9 | "epochs": 50, 10 | "runs": 5, 11 | "batch_size": 64, 12 | "lr": 0.001, 13 | "weight_decay": 0, 14 | "clip": 5, 15 | "print_every": 50, 16 | "dataset": "solar", 17 | "graph_path": "data/solar-Gt.pkl", 18 | "out_dir": "save_tgnet" 19 | }, 20 | 21 | "net_params": { 22 | "model": "traversenet", 23 | "in_dim": 1, 24 | "dim": 64, 25 | "num_layers": 3, 26 | "heads": 1, 27 | "num_nodes": 137, 28 | "dropout": 0.0, 29 | "seq_in_len": 12, 30 | "seq_out_len": 12, 31 | "cl_decay_steps": 500 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnzhan/TraverseNet/ba4ce7478386cb478293f5283a94c40bacdec0cc/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | from utils.process import * 5 | import json 6 | 7 | class StandardScaler(): 8 | """ 9 | Standard the input 10 | """ 11 | def 
class StandardScaler:
    """Standardize inputs with fixed mean/std statistics (torch tensors).

    ``transform`` normalizes on whatever device the data lives on;
    ``inverse_transform`` denormalizes a single feature channel ``dim``.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        # Move the statistics onto the data's device so this works on CPU or GPU.
        return (data - self.mean.to(data.device)) / self.std.to(data.device)

    def inverse_transform(self, data, dim):
        # ``dim`` selects the feature channel whose statistics are applied.
        return (data * self.std[..., dim].item()) + self.mean[..., dim].item()


class DataLoader(object):
    """Simple in-memory mini-batch iterator over paired (xs, ys) arrays."""

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True):
        """
        :param xs: input samples; the first axis is the sample axis.
        :param ys: target samples aligned with ``xs``.
        :param batch_size: number of samples per batch.
        :param pad_with_last_sample: pad with the last sample to make the
            number of samples divisible by ``batch_size``.
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
            x_padding = np.repeat(xs[-1:], num_padding, axis=0)
            y_padding = np.repeat(ys[-1:], num_padding, axis=0)
            xs = np.concatenate([xs, x_padding], axis=0)
            ys = np.concatenate([ys, y_padding], axis=0)

        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        self.xs = xs
        self.ys = ys

    def shuffle(self):
        """Shuffle xs and ys in unison (call between epochs)."""
        permutation = np.random.permutation(self.size)
        self.xs, self.ys = self.xs[permutation], self.ys[permutation]

    def get_iterator(self):
        """Return a generator yielding ``(x_batch, y_batch)`` tuples."""
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                start_ind = self.batch_size * self.current_ind
                end_ind = min(self.size, self.batch_size * (self.current_ind + 1))
                x_i = self.xs[start_ind: end_ind, ...]
                y_i = self.ys[start_ind: end_ind, ...]
                yield (x_i, y_i)
                self.current_ind += 1

        return _wrapper()


def _window_pairs(x, start, stop):
    """Slice ``x`` (last axis = time) into 12-step input / 12-step target
    windows for every offset ``i`` in ``range(start, stop)``.

    Returns ``(inputs, targets)`` concatenated along the sample axis with
    axes 1 and 3 swapped, exactly as the original inline loops produced.
    Raises on an empty range (``torch.cat`` of an empty list), matching the
    original behavior for too-short series.
    """
    xs, ys = [], []
    for i in range(start, stop):
        xs.append(x[..., i:i + 12])
        ys.append(x[..., i + 12:i + 24])
    xs = torch.cat(xs, dim=0).transpose(3, 1)
    ys = torch.cat(ys, dim=0).transpose(3, 1)
    return xs, ys


def _centered_cosine(x1, x2):
    """Cosine similarity of the mean-centered series (Pearson-style).

    NOTE(review): a constant series yields a zero denominator (NaN) — the
    original code behaved the same way; confirm inputs are non-constant.
    """
    x1 = x1 - torch.mean(x1)
    x2 = x2 - torch.mean(x2)
    n = torch.sum(x1 * x2)
    d1 = torch.sum(x1 ** 2) ** 0.5
    d2 = torch.sum(x2 ** 2) ** 0.5
    return n / (d1 * d2)


class PemsData:
    """Pre-processor for the PEMS .npz traffic datasets."""

    def __init__(self, num_nodes, path, adjpath, idpath=None):
        self.num_nodes = num_nodes
        self.path = path        # .npz archive holding a 'data' array
        self.adjpath = adjpath  # csv edge list: src,tgt,distance
        self.idpath = idpath    # optional file mapping raw sensor ids to 0..N-1

    def load(self):
        """Load and return the raw array stored under 'data' in the .npz."""
        data = np.load(self.path)
        return data['data']

    def prcoess(self, savepath):
        """Build train/val/test windows plus a similarity graph, pickle the
        result to ``savepath`` and return ``(adj, srclist, tgtlist, distlist)``.

        Name kept misspelled for backward compatibility; see ``process`` alias.
        """
        data = {}
        x = self.load()
        x = x.transpose()        # reverse all axes, e.g. (T, N, F) -> (F, N, T)
        x = torch.Tensor(x)
        length = x.shape[2]      # time length, taken before adding the batch axis
        x = x.unsqueeze(dim=0)

        # 60/20/20 chronological split; each sample is 12 steps in, 12 out.
        data['x_train'], data['y_train'] = _window_pairs(x, 0, int(length * 0.6) - 24)
        data['x_val'], data['y_val'] = _window_pairs(x, int(length * 0.6), int(length * 0.8) - 24)
        data['x_test'], data['y_test'] = _window_pairs(x, int(length * 0.8), length - 24)

        # Graph is built only from the training portion (no leakage from val/test).
        data['adj'], srclist, tgtlist, distlist = self.load_graph(
            x[..., 0:int(length * 0.6) - 1].squeeze())
        # Fixed: the file handle was previously opened and never closed.
        with open(savepath, "wb") as file:
            pickle.dump(data, file)
        return data['adj'], srclist, tgtlist, distlist

    # Correctly spelled, backward-compatible alias for ``prcoess``.
    process = prcoess

    def cos(self, x1, x2):
        """Centered cosine similarity between two 1-D series."""
        return _centered_cosine(x1, x2)

    def load_graph(self, x):
        """Threshold pairwise centered-cosine similarity of node series
        (indexed as ``x[0, i, :]``) into an adjacency matrix.

        Returns ``(adj, srclist, tgtlist, dislist)`` where the lists describe
        the retained directed edges; self-loops get weight 1 and distance 0.
        """
        adj = torch.zeros(self.num_nodes, self.num_nodes)
        srclist = []
        tgtlist = []
        dislist = []
        # The original if/elif chain on the file name chose 0.995 in every
        # branch, so it collapses to a single constant.
        thr = 0.995
        print(x.shape)  # debug leftover, kept to preserve observable behavior
        for i in range(self.num_nodes):
            for j in range(self.num_nodes):
                w = self.cos(x[0, i, :], x[0, j, :])
                if i == j:
                    adj[i, j] = 1
                    srclist.append(i)
                    tgtlist.append(j)
                    dislist.append(0)
                    continue
                # Keep strongly correlated or strongly anti-correlated pairs.
                if w >= thr or w <= -thr:
                    adj[i, j] = w
                    srclist.append(i)
                    tgtlist.append(j)
                    dislist.append(w.item())
        return adj, srclist, tgtlist, dislist

    def load_graph1(self):
        """Load a weighted adjacency matrix from the csv edge list.

        Returns ``(adj, srclist, tgtlist, dislist)`` with a self-loop
        (distance 0) prepended for every node, as the trainers expect.
        """
        node2id = dict()
        if self.idpath is not None:
            # Map raw sensor ids (one per line) to consecutive indices.
            # Fixed: files were previously opened without ever being closed.
            with open(self.idpath) as file:
                next_id = 0
                for li in file:
                    li = li.strip()
                    node2id[int(li)] = next_id
                    next_id += 1

        nodes = [i for i in range(self.num_nodes)]
        dist = [0 for i in range(self.num_nodes)]
        adj = torch.eye(self.num_nodes)

        with open(self.adjpath) as file:
            for li in file:
                li = li.strip().split(',')
                try:
                    li = [float(t) for t in li]
                except Exception:
                    # Skip header / malformed rows, as before.
                    continue
                if self.idpath is not None:
                    src = int(node2id[li[0]])
                    tgt = int(node2id[li[1]])
                else:
                    src = int(li[0])
                    tgt = int(li[1])
                if src != tgt:
                    adj[src, tgt] = li[2]

        srclist = []
        tgtlist = []
        dislist = []
        for i in range(self.num_nodes):
            for j in range(self.num_nodes):
                if adj[i, j] > 1e-9 and i != j:
                    srclist.append(i)
                    tgtlist.append(j)
                    dislist.append(adj[i, j].item())
        return adj, nodes + srclist, nodes + tgtlist, dist + dislist


class PoxData:
    """Pre-processor for JSON datasets holding a node/time matrix under 'X'."""

    def __init__(self, path, xkey):
        self.path = path
        # Fixed: use a context manager instead of manual open/close.
        with open(self.path) as file:
            self.store = json.load(file)
        self.num_nodes = len(self.store['X'])
        self.xkey = xkey

    def prcoess(self, savepath):
        """Build train/val/test windows plus the graph and pickle to
        ``savepath``.

        Fixed for consistency with ``PemsData.prcoess``: now also returns
        ``(adj, srclist, tgtlist, distlist)`` (previously returned ``None``,
        so existing callers that ignore the return value are unaffected).
        """
        data = {}
        x = np.array(self.store[self.xkey])
        x = x.transpose()   # assumes a 2-D matrix in the JSON — TODO confirm schema
        x = torch.Tensor(x)
        x = x.unsqueeze(0)
        length = x.shape[2]  # time length, taken before the second unsqueeze
        x = x.unsqueeze(dim=0)

        # Same 60/20/20 chronological split as PemsData.
        data['x_train'], data['y_train'] = _window_pairs(x, 0, int(length * 0.6) - 24)
        data['x_val'], data['y_val'] = _window_pairs(x, int(length * 0.6), int(length * 0.8) - 24)
        data['x_test'], data['y_test'] = _window_pairs(x, int(length * 0.8), length - 24)

        data['adj'], srclist, tgtlist, distlist = self.load_graph()
        # Fixed: the file handle was previously opened and never closed.
        with open(savepath, "wb") as file:
            pickle.dump(data, file)
        return data['adj'], srclist, tgtlist, distlist

    # Correctly spelled, backward-compatible alias for ``prcoess``.
    process = prcoess

    def cos(self, x1, x2):
        """Centered cosine similarity between two 1-D series."""
        return _centered_cosine(x1, x2)

    # NOTE(review): ``PoxData.load_graph`` continues beyond this excerpt and
    # is intentionally not reproduced here.
284 | nodes = [i for i in range(self.num_nodes)] 285 | dist = [0 for i in range(self.num_nodes)] 286 | adj = torch.eye(self.num_nodes) 287 | for i in range(len(self.store["edges"])): 288 | src = self.store["edges"][i][0] 289 | tgt = self.store["edges"][i][1] 290 | if src != tgt: 291 | if "weights" not in self.store.keys(): 292 | adj[src, tgt] = 1 293 | else: 294 | adj[src, tgt] = self.store["weights"][i] 295 | 296 | srclist = [] 297 | tgtlist = [] 298 | dislist = [] 299 | for i in range(self.num_nodes): 300 | for j in range(self.num_nodes): 301 | if adj[i, j] > 1e-9 and i!=j: 302 | srclist.append(i) 303 | tgtlist.append(j) 304 | dislist.append(adj[i, j].item()) 305 | return adj, nodes+srclist, nodes+tgtlist, dist+dislist 306 | 307 | 308 | class MulData: 309 | def __init__(self, path): 310 | self.path = path 311 | self.x = np.loadtxt(path, delimiter=",") 312 | self.num_nodes = self.x.shape[1] 313 | 314 | def prcoess(self, savepath): 315 | data = {} 316 | x = self.x 317 | x = x.transpose() 318 | x = torch.Tensor(x) 319 | x = x.unsqueeze(0) 320 | length = x.shape[2] 321 | trainx = [] 322 | trainy = [] 323 | valx = [] 324 | valy = [] 325 | testx = [] 326 | testy = [] 327 | 328 | x = x.unsqueeze(dim=0) 329 | 330 | for i in range(int(length*0.6)-24): 331 | tx = x[...,i:i+12] 332 | ty = x[...,i+12:i+24] 333 | trainx.append(tx) 334 | trainy.append(ty) 335 | for i in range(int(length*0.6),int(length*0.8)-24): 336 | tx = x[...,i:i+12] 337 | ty = x[...,i+12:i+24] 338 | valx.append(tx) 339 | valy.append(ty) 340 | for i in range(int(length*0.8), length-24): 341 | tx = x[...,i:i+12] 342 | ty = x[...,i+12:i+24] 343 | testx.append(tx) 344 | testy.append(ty) 345 | 346 | trainx = torch.cat(trainx,dim=0) 347 | trainx = trainx.transpose(3,1) 348 | trainy = torch.cat(trainy,dim=0) 349 | trainy = trainy.transpose(3,1) 350 | 351 | valx = torch.cat(valx,dim=0) 352 | valx = valx.transpose(3,1) 353 | valy = torch.cat(valy,dim=0) 354 | valy = valy.transpose(3,1) 355 | 356 | testx = 
torch.cat(testx,dim=0) 357 | testx = testx.transpose(3,1) 358 | testy = torch.cat(testy,dim=0) 359 | testy = testy.transpose(3,1) 360 | 361 | data['x_train'] = trainx 362 | data['y_train'] = trainy 363 | data['x_val'] = valx 364 | data['y_val'] = valy 365 | data['x_test'] = testx 366 | data['y_test'] = testy 367 | data['adj'], srclist, tgtlist, distlist = self.load_graph(x[...,0:int(length*0.6)-1].squeeze()) 368 | file = open(savepath, "wb") 369 | pickle.dump(data, file) 370 | return data['adj'], srclist, tgtlist, distlist 371 | 372 | def cos(self, x1, x2): 373 | x1 = x1 - torch.mean(x1) 374 | x2 = x2 - torch.mean(x2) 375 | n = torch.sum(x1*x2) 376 | d1 = torch.sum(x1**2)**0.5 377 | d2 = torch.sum(x2**2)**0.5 378 | out = n/(d1*d2) 379 | return out 380 | 381 | def load_graph(self, x): 382 | adj = torch.zeros(self.num_nodes,self.num_nodes) 383 | srclist = [] 384 | tgtlist = [] 385 | dislist = [] 386 | if "solar" in self.path: 387 | thr = 0.975 388 | elif "electricity" in self.path: 389 | thr = 0.93 390 | else: 391 | thr = 0 392 | for i in range(self.num_nodes): 393 | for j in range(self.num_nodes): 394 | w = self.cos(x[i, :], x[j, :]) 395 | if i == j: 396 | adj[i, j] = 1 397 | srclist.append(i) 398 | tgtlist.append(j) 399 | dislist.append(0) 400 | continue 401 | if w >= thr or w<=-thr: 402 | adj[i, j] = w 403 | srclist.append(i) 404 | tgtlist.append(j) 405 | dislist.append(w.item()) 406 | return adj, srclist, tgtlist, dislist 407 | 408 | 409 | class WindmillData: 410 | def __init__(self, path): 411 | self.path = path 412 | file = open(self.path) 413 | self.store = json.load(file) 414 | self.num_nodes = len(self.store['block'][0]) 415 | file.close() 416 | 417 | def prcoess(self, savepath): 418 | data = {} 419 | x = np.array(self.store['block']) 420 | x = x.transpose() 421 | x = torch.Tensor(x) 422 | x = x.unsqueeze(0) 423 | length = x.shape[2] 424 | trainx = [] 425 | trainy = [] 426 | valx = [] 427 | valy = [] 428 | testx = [] 429 | testy = [] 430 | 431 | x = 
x.unsqueeze(dim=0) 432 | 433 | for i in range(int(length*0.6)-24): 434 | tx = x[...,i:i+12] 435 | ty = x[...,i+12:i+24] 436 | trainx.append(tx) 437 | trainy.append(ty) 438 | for i in range(int(length*0.6),int(length*0.8)-24): 439 | tx = x[...,i:i+12] 440 | ty = x[...,i+12:i+24] 441 | valx.append(tx) 442 | valy.append(ty) 443 | for i in range(int(length*0.8), length-24): 444 | tx = x[...,i:i+12] 445 | ty = x[...,i+12:i+24] 446 | testx.append(tx) 447 | testy.append(ty) 448 | 449 | trainx = torch.cat(trainx,dim=0) 450 | trainx = trainx.transpose(3,1) 451 | trainy = torch.cat(trainy,dim=0) 452 | trainy = trainy.transpose(3,1) 453 | 454 | valx = torch.cat(valx,dim=0) 455 | valx = valx.transpose(3,1) 456 | valy = torch.cat(valy,dim=0) 457 | valy = valy.transpose(3,1) 458 | 459 | testx = torch.cat(testx,dim=0) 460 | testx = testx.transpose(3,1) 461 | testy = torch.cat(testy,dim=0) 462 | testy = testy.transpose(3,1) 463 | 464 | data['x_train'] = trainx 465 | data['y_train'] = trainy 466 | data['x_val'] = valx 467 | data['y_val'] = valy 468 | data['x_test'] = testx 469 | data['y_test'] = testy 470 | 471 | data['adj'], srclist, tgtlist, distlist = self.load_graph() 472 | file = open(savepath, "wb") 473 | pickle.dump(data, file) 474 | 475 | def load_graph(self): 476 | nodes = [i for i in range(self.num_nodes)] 477 | dist = [0 for i in range(self.num_nodes)] 478 | adj = torch.zeros(self.num_nodes,self.num_nodes) 479 | for i in range(len(self.store["edges"])): 480 | src = self.store["edges"][i][0] 481 | tgt = self.store["edges"][i][1] 482 | if src != tgt: 483 | if "weights" not in self.store.keys(): 484 | adj[src, tgt] = 1 485 | else: 486 | adj[src, tgt] = self.store["weights"][i] 487 | 488 | srclist = [] 489 | tgtlist = [] 490 | dislist = [] 491 | for i in range(self.num_nodes): 492 | for j in range(self.num_nodes): 493 | if adj[i, j] > 1e-9 and i!=j: 494 | srclist.append(i) 495 | tgtlist.append(j) 496 | dislist.append(adj[i, j].item()) 497 | return adj, nodes+srclist, 
nodes+tgtlist, dist+dislist 498 | 499 | 500 | def load_data(batch_size, path, device=None, normalize=True): 501 | file = open(path, "rb") 502 | data = pickle.load(file) 503 | mean = data['x_train'].mean(axis=(0, 1, 2), keepdims=True) 504 | std = data['x_train'].std(axis=(0, 1, 2), keepdims=True) 505 | if normalize: 506 | scaler = StandardScaler(mean=mean, std=std) 507 | for category in ['train', 'val', 'test']: 508 | data['x_' + category] = scaler.transform(data['x_' + category]) 509 | data['y_' + category] = scaler.transform(data['y_' + category]) 510 | else: 511 | scaler = StandardScaler(mean=0, std=1) 512 | 513 | data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size) 514 | data['val_loader'] = DataLoader(data['x_val'], data['y_val'], batch_size) 515 | data['test_loader'] = DataLoader(data['x_test'], data['y_test'], batch_size) 516 | data['scaler'] = scaler 517 | return data 518 | 519 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnzhan/TraverseNet/ba4ce7478386cb478293f5283a94c40bacdec0cc/layers/__init__.py -------------------------------------------------------------------------------- /layers/dcrnn_cell.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from utils.process import * 5 | 6 | class LayerParams: 7 | def __init__(self, rnn_network: torch.nn.Module, layer_type: str): 8 | self._rnn_network = rnn_network 9 | self._params_dict = {} 10 | self._biases_dict = {} 11 | self._type = layer_type 12 | 13 | def get_weights(self, shape, device): 14 | if shape not in self._params_dict: 15 | nn_param = torch.nn.Parameter(torch.empty(*shape, device=device)) 16 | torch.nn.init.xavier_normal_(nn_param) 17 | self._params_dict[shape] = nn_param 18 | 
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)), 19 | nn_param) 20 | return self._params_dict[shape] 21 | 22 | def get_biases(self, length, device, bias_start=0.0): 23 | if length not in self._biases_dict: 24 | biases = torch.nn.Parameter(torch.empty(length, device=device)) 25 | torch.nn.init.constant_(biases, bias_start) 26 | self._biases_dict[length] = biases 27 | self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)), 28 | biases) 29 | 30 | return self._biases_dict[length] 31 | 32 | 33 | class DCGRUCell(torch.nn.Module): 34 | def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, device, nonlinearity='tanh', 35 | filter_type="laplacian", use_gc_for_ru=True): 36 | """ 37 | :param num_units: 38 | :param adj_mx: 39 | :param max_diffusion_step: 40 | :param num_nodes: 41 | :param nonlinearity: 42 | :param filter_type: "laplacian", "random_walk", "dual_random_walk". 43 | :param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates. 44 | """ 45 | 46 | super().__init__() 47 | self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu 48 | # support other nonlinearities up here? 
49 | self._num_nodes = num_nodes 50 | self._num_units = num_units 51 | self._max_diffusion_step = max_diffusion_step 52 | self._supports = [] 53 | self._use_gc_for_ru = use_gc_for_ru 54 | supports = [] 55 | if filter_type == "laplacian": 56 | supports.append(calculate_scaled_laplacian(adj_mx, lambda_max=None)) 57 | elif filter_type == "random_walk": 58 | supports.append(calculate_random_walk_matrix(adj_mx).T) 59 | elif filter_type == "dual_random_walk": 60 | supports.append(calculate_random_walk_matrix(adj_mx).T) 61 | supports.append(calculate_random_walk_matrix(adj_mx.T).T) 62 | else: 63 | supports.append(calculate_scaled_laplacian(adj_mx)) 64 | for support in supports: 65 | self._supports.append(self._build_sparse_matrix(support, device)) 66 | 67 | self.device = device 68 | self._fc_params = LayerParams(self, 'fc') 69 | self._gconv_params = LayerParams(self, 'gconv') 70 | 71 | @staticmethod 72 | def _build_sparse_matrix(L,device): 73 | L = L.tocoo() 74 | indices = np.column_stack((L.row, L.col)) 75 | # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L) 76 | indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))] 77 | L = torch.sparse_coo_tensor(indices.T, L.data, L.shape,dtype=torch.float32, device=device) 78 | return L 79 | 80 | def forward(self, inputs, hx): 81 | """Gated recurrent unit (GRU) with Graph Convolution. 82 | :param inputs: (B, num_nodes * input_dim) 83 | :param hx: (B, num_nodes * rnn_units) 84 | :return 85 | - Output: A `2-D` tensor with shape `(B, num_nodes * rnn_units)`. 
86 | """ 87 | output_size = 2 * self._num_units 88 | if self._use_gc_for_ru: 89 | fn = self._gconv 90 | else: 91 | fn = self._fc 92 | value = torch.sigmoid(fn(inputs, hx, output_size, bias_start=1.0)) 93 | value = torch.reshape(value, (-1, self._num_nodes, output_size)) 94 | r, u = torch.split(tensor=value, split_size_or_sections=self._num_units, dim=-1) 95 | r = torch.reshape(r, (-1, self._num_nodes * self._num_units)) 96 | u = torch.reshape(u, (-1, self._num_nodes * self._num_units)) 97 | 98 | c = self._gconv(inputs, r * hx, self._num_units) 99 | if self._activation is not None: 100 | c = self._activation(c) 101 | 102 | new_state = u * hx + (1.0 - u) * c 103 | return new_state 104 | 105 | @staticmethod 106 | def _concat(x, x_): 107 | x_ = x_.unsqueeze(0) 108 | return torch.cat([x, x_], dim=0) 109 | 110 | def _fc(self, inputs, state, output_size, bias_start=0.0): 111 | batch_size = inputs.shape[0] 112 | inputs = torch.reshape(inputs, (batch_size * self._num_nodes, -1)) 113 | state = torch.reshape(state, (batch_size * self._num_nodes, -1)) 114 | inputs_and_state = torch.cat([inputs, state], dim=-1) 115 | input_size = inputs_and_state.shape[-1] 116 | weights = self._fc_params.get_weights((input_size, output_size)) 117 | value = torch.sigmoid(torch.matmul(inputs_and_state, weights)) 118 | biases = self._fc_params.get_biases(output_size, bias_start) 119 | value += biases 120 | return value 121 | 122 | def _gconv(self, inputs, state, output_size, bias_start=0.0): 123 | # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim) 124 | batch_size = inputs.shape[0] 125 | inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1)) 126 | state = torch.reshape(state, (batch_size, self._num_nodes, -1)) 127 | inputs_and_state = torch.cat([inputs, state], dim=2) 128 | input_size = inputs_and_state.size(2) 129 | 130 | x = inputs_and_state 131 | x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size) 132 | x0 = torch.reshape(x0, 
shape=[self._num_nodes, input_size * batch_size]) 133 | x = torch.unsqueeze(x0, 0) 134 | 135 | if self._max_diffusion_step == 0: 136 | pass 137 | else: 138 | for support in self._supports: 139 | x1 = torch.sparse.mm(support, x0) 140 | x = self._concat(x, x1) 141 | 142 | for k in range(2, self._max_diffusion_step + 1): 143 | x2 = 2 * torch.sparse.mm(support, x1) - x0 144 | x = self._concat(x, x2) 145 | x1, x0 = x2, x1 146 | 147 | num_matrices = len(self._supports) * self._max_diffusion_step + 1 # Adds for x itself. 148 | x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size]) 149 | x = x.permute(3, 1, 2, 0) # (batch_size, num_nodes, input_size, order) 150 | x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices]) 151 | 152 | weights = self._gconv_params.get_weights((input_size * num_matrices, output_size),self.device) 153 | x = torch.matmul(x, weights) # (batch_size * self._num_nodes, output_size) 154 | 155 | biases = self._gconv_params.get_biases(output_size, self.device, bias_start) 156 | x += biases 157 | # Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim) 158 | return torch.reshape(x, [batch_size, self._num_nodes * output_size]) -------------------------------------------------------------------------------- /layers/gat_layer.py: -------------------------------------------------------------------------------- 1 | """Torch modules for graph attention networks(GAT).""" 2 | # pylint: disable= no-member, arguments-differ, invalid-name 3 | import torch as th 4 | from torch import nn 5 | 6 | from dgl import function as fn 7 | from dgl.ops import edge_softmax 8 | from dgl.base import DGLError 9 | from dgl.nn.pytorch.utils import Identity 10 | from dgl.utils import expand_as_pair 11 | import dgl 12 | import pickle 13 | import numpy as np 14 | 15 | # pylint: enable=W0235 16 | class GATConvs(nn.Module): 17 | r""" 18 | Description 19 | ----------- 20 | Apply `Graph 
Attention Network `__ 21 | over an input signal. 22 | .. math:: 23 | h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)} 24 | where :math:`\alpha_{ij}` is the attention score bewteen node :math:`i` and 25 | node :math:`j`: 26 | .. math:: 27 | \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l}) 28 | e_{ij}^{l} &= \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i} \| W h_{j}]\right) 29 | Parameters 30 | ---------- 31 | in_feats : int, or pair of ints 32 | Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`. 33 | ATConv can be applied on homogeneous graph and unidirectional 34 | `bipartite graph `__. 35 | If the layer is to be applied to a unidirectional bipartite graph, ``in_feats`` 36 | specifies the input feature size on both the source and destination nodes. If 37 | a scalar is given, the source and destination node feature size would take the 38 | same value. 39 | out_feats : int 40 | Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`. 41 | num_heads : int 42 | Number of heads in Multi-Head Attention. 43 | feat_drop : float, optional 44 | Dropout rate on feature. Defaults: ``0``. 45 | attn_drop : float, optional 46 | Dropout rate on attention weight. Defaults: ``0``. 47 | negative_slope : float, optional 48 | LeakyReLU angle of negative slope. Defaults: ``0.2``. 49 | residual : bool, optional 50 | If True, use residual connection. Defaults: ``False``. 51 | activation : callable activation function/layer or None, optional. 52 | If not None, applies an activation function to the updated node features. 53 | Default: ``None``. 54 | allow_zero_in_degree : bool, optional 55 | If there are 0-in-degree nodes in the graph, output for those nodes will be invalid 56 | since no message will be passed to those nodes. This is harmful for some applications 57 | causing silent performance regression. This module will raise a DGLError if it detects 58 | 0-in-degree nodes in input graph. 
By setting ``True``, it will suppress the check 59 | and let the users handle it by themselves. Defaults: ``False``. 60 | Note 61 | ---- 62 | Zero in-degree nodes will lead to invalid output value. This is because no message 63 | will be passed to those nodes, the aggregation function will be appied on empty input. 64 | A common practice to avoid this is to add a self-loop for each node in the graph if 65 | it is homogeneous, which can be achieved by: 66 | >>> g = ... # a DGLGraph 67 | >>> g = dgl.add_self_loop(g) 68 | Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph 69 | since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree`` 70 | to ``True`` for those cases to unblock the code and handle zere-in-degree nodes manually. 71 | A common practise to handle this is to filter out the nodes with zere-in-degree when use 72 | after conv. 73 | Examples 74 | -------- 75 | >>> import dgl 76 | >>> import numpy as np 77 | >>> import torch as th 78 | >>> from dgl.nn import GATConv 79 | >>> # Case 1: Homogeneous graph 80 | >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) 81 | >>> g = dgl.add_self_loop(g) 82 | >>> feat = th.ones(6, 10) 83 | >>> gatconv = GATConv(10, 2, num_heads=3) 84 | >>> res = gatconv(g, feat) 85 | >>> res 86 | tensor([[[ 3.4570, 1.8634], 87 | [ 1.3805, -0.0762], 88 | [ 1.0390, -1.1479]], 89 | [[ 3.4570, 1.8634], 90 | [ 1.3805, -0.0762], 91 | [ 1.0390, -1.1479]], 92 | [[ 3.4570, 1.8634], 93 | [ 1.3805, -0.0762], 94 | [ 1.0390, -1.1479]], 95 | [[ 3.4570, 1.8634], 96 | [ 1.3805, -0.0762], 97 | [ 1.0390, -1.1479]], 98 | [[ 3.4570, 1.8634], 99 | [ 1.3805, -0.0762], 100 | [ 1.0390, -1.1479]], 101 | [[ 3.4570, 1.8634], 102 | [ 1.3805, -0.0762], 103 | [ 1.0390, -1.1479]]], grad_fn=) 104 | >>> # Case 2: Unidirectional bipartite graph 105 | >>> u = [0, 1, 0, 0, 1] 106 | >>> v = [0, 1, 2, 3, 2] 107 | >>> g = dgl.bipartite((u, v)) 108 | >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32)) 
109 | >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32)) 110 | >>> gatconv = GATConv((5,10), 2, 3) 111 | >>> res = gatconv(g, (u_feat, v_feat)) 112 | >>> res 113 | tensor([[[-0.6066, 1.0268], 114 | [-0.5945, -0.4801], 115 | [ 0.1594, 0.3825]], 116 | [[ 0.0268, 1.0783], 117 | [ 0.5041, -1.3025], 118 | [ 0.6568, 0.7048]], 119 | [[-0.2688, 1.0543], 120 | [-0.0315, -0.9016], 121 | [ 0.3943, 0.5347]], 122 | [[-0.6066, 1.0268], 123 | [-0.5945, -0.4801], 124 | [ 0.1594, 0.3825]]], grad_fn=) 125 | """ 126 | def __init__(self, 127 | in_feats, 128 | out_feats, 129 | num_heads, 130 | num_nodes, 131 | layerid, 132 | feat_drop=0., 133 | attn_drop=0., 134 | negative_slope=0.2, 135 | residual=False, 136 | activation=None, 137 | allow_zero_in_degree=False,fix=False): 138 | super(GATConvs, self).__init__() 139 | self._num_heads = num_heads 140 | self.num_nodes = num_nodes 141 | self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) 142 | self._out_feats = out_feats 143 | self._allow_zero_in_degree = allow_zero_in_degree 144 | self.layerid = layerid 145 | if isinstance(in_feats, tuple): 146 | self.fc_src = nn.Linear( 147 | self._in_src_feats, out_feats * num_heads, bias=False) 148 | self.fc_dst = nn.Linear( 149 | self._in_dst_feats, out_feats * num_heads, bias=False) 150 | else: 151 | self.fc1 = nn.Linear( 152 | self._in_src_feats, out_feats * num_heads, bias=False) 153 | self.fc2 = nn.Linear( 154 | self._in_src_feats, out_feats * num_heads, bias=False) 155 | self.fc3 = nn.Linear( 156 | self._in_src_feats, out_feats * num_heads, bias=False) 157 | self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) 158 | self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats))) 159 | self.feat_drop = nn.Dropout(feat_drop) 160 | self.attn_drop = nn.Dropout(attn_drop) 161 | self.leaky_relu = nn.LeakyReLU(negative_slope) 162 | if residual: 163 | if self._in_dst_feats != out_feats: 164 | self.res_fc = nn.Linear( 165 | self._in_dst_feats, 
num_heads * out_feats, bias=False) 166 | else: 167 | self.res_fc = Identity() 168 | else: 169 | self.register_buffer('res_fc', None) 170 | self.reset_parameters() 171 | self.activation = activation 172 | self.fix = fix 173 | 174 | def reset_parameters(self): 175 | """ 176 | Description 177 | ----------- 178 | Reinitialize learnable parameters. 179 | Note 180 | ---- 181 | The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization. 182 | The attention weights are using xavier initialization method. 183 | """ 184 | gain = nn.init.calculate_gain('relu') 185 | if hasattr(self, 'fc1'): 186 | nn.init.xavier_normal_(self.fc1.weight, gain=gain) 187 | nn.init.xavier_normal_(self.fc2.weight, gain=gain) 188 | nn.init.xavier_normal_(self.fc3.weight, gain=gain) 189 | 190 | else: 191 | nn.init.xavier_normal_(self.fc_src.weight, gain=gain) 192 | nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) 193 | nn.init.xavier_normal_(self.attn_l, gain=gain) 194 | nn.init.xavier_normal_(self.attn_r, gain=gain) 195 | if isinstance(self.res_fc, nn.Linear): 196 | nn.init.xavier_normal_(self.res_fc.weight, gain=gain) 197 | 198 | def set_allow_zero_in_degree(self, set_value): 199 | r""" 200 | Description 201 | ----------- 202 | Set allow_zero_in_degree flag. 203 | Parameters 204 | ---------- 205 | set_value : bool 206 | The value to be set to the flag. 207 | """ 208 | self._allow_zero_in_degree = set_value 209 | 210 | def forward(self, graph, feat): 211 | r""" 212 | Description 213 | ----------- 214 | Compute graph attention network layer. 215 | Parameters 216 | ---------- 217 | graph : DGLGraph 218 | The graph. 219 | feat : torch.Tensor or pair of torch.Tensor 220 | If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where 221 | :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. 
222 | If a pair of torch.Tensor is given, the pair must contain two tensors of shape 223 | :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`. 224 | Returns 225 | ------- 226 | torch.Tensor 227 | The output feature of shape :math:`(N, H, D_{out})` where :math:`H` 228 | is the number of heads, and :math:`D_{out}` is size of output feature. 229 | Raises 230 | ------ 231 | DGLError 232 | If there are 0-in-degree nodes in the input graph, it will raise DGLError 233 | since no message will be passed to those nodes. This will cause invalid output. 234 | The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``. 235 | """ 236 | with graph.local_scope(): 237 | if not self._allow_zero_in_degree: 238 | if (graph.in_degrees() == 0).any(): 239 | raise DGLError('There are 0-in-degree nodes in the graph, ' 240 | 'output for those nodes will be invalid. ' 241 | 'This is harmful for some applications, ' 242 | 'causing silent performance regression. ' 243 | 'Adding self-loop on the input graph by ' 244 | 'calling `g = dgl.add_self_loop(g)` will resolve ' 245 | 'the issue. 
Setting ``allow_zero_in_degree`` ' 246 | 'to be `True` when constructing this module will ' 247 | 'suppress the check and let the code run.') 248 | 249 | if isinstance(feat, tuple): 250 | h_src = self.feat_drop(feat[0]) 251 | h_dst = self.feat_drop(feat[1]) 252 | if not hasattr(self, 'fc_src'): 253 | self.fc_src, self.fc_src1, self.fc_dst = self.fc1, self.fc1, self.fc1 254 | shape = h_src.shape 255 | feat_src = self.fc_src(h_src).view(*shape[:-1], self._num_heads, self._out_feats) 256 | feat_src1 = self.fc_src(h_src).view(*shape[:-1], self._num_heads, self._out_feats) 257 | feat_dst = self.fc_src(h_dst).view(*shape[:-1], self._num_heads, self._out_feats) 258 | else: 259 | h_src = h_dst = self.feat_drop(feat) 260 | shape = h_src.shape 261 | feat_src = self.fc1(h_src).view( 262 | *shape[:-1], self._num_heads, self._out_feats) 263 | feat_src1 = self.fc1(h_src).view( 264 | *shape[:-1], self._num_heads, self._out_feats) 265 | feat_dst = self.fc1(h_src).view( 266 | *shape[:-1], self._num_heads, self._out_feats) 267 | if graph.is_block: 268 | feat_dst = feat_src[:graph.number_of_dst_nodes()] 269 | # NOTE: GAT paper uses "first concatenation then linear projection" 270 | # to compute attention scores, while ours is "first projection then 271 | # addition", the two approaches are mathematically equivalent: 272 | # We decompose the weight vector a mentioned in the paper into 273 | # [a_l || a_r], then 274 | # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j 275 | # Our implementation is much efficient because we do not need to 276 | # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus, 277 | # addition could be optimized with DGL's built-in function u_add_v, 278 | # which further speeds up computation and saves memory footprint. 
279 | el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1) 280 | er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1) 281 | graph.srcdata.update({'ft': feat_src1, 'el': el}) 282 | graph.dstdata.update({'er': er}) 283 | 284 | # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. 285 | graph.apply_edges(fn.u_add_v('el', 'er', 'e')) 286 | 287 | e = self.leaky_relu(graph.edata.pop('e')) 288 | if self.fix: 289 | e = th.ones_like(e).to(e.device) 290 | # compute softmax 291 | #graph.edata['a'] = self.attn_drop(edge_weights) 292 | 293 | #graph = graph.reverse(share_ndata=True, share_edata=True) 294 | graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) 295 | 296 | #graph = graph.reverse(share_ndata=True, share_edata=True) 297 | # if not self.training: 298 | # sn, tg = graph.edges() 299 | # data = dict() 300 | # data['a'] = graph.edata['a'].cpu() 301 | # data['src'] = sn.cpu() 302 | # data['tgt'] = tg.cpu() 303 | # file = open('./att_l' + str(self.layerid) + "_" + str(int(np.random.random() * 10000)) + '.pkl', "wb") 304 | # pickle.dump(data, file) 305 | # file.close() 306 | 307 | # message passing 308 | graph.update_all(fn.u_mul_e('ft', 'a', 'm'), 309 | fn.sum('m', 'ft')) 310 | 311 | rst = graph.dstdata['ft'] 312 | # residual 313 | #if self.res_fc is not None: 314 | #resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats) 315 | # rst = rst + resval 316 | # activation 317 | if self.activation: 318 | rst = self.activation(rst) 319 | return rst 320 | -------------------------------------------------------------------------------- /layers/layernorm.py: -------------------------------------------------------------------------------- 1 | from torch.nn import init 2 | import numbers 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | class LayerNorm(nn.Module): 8 | __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine'] 9 | def __init__(self, normalized_shape, eps=1e-5, 
elementwise_affine=True): 10 | super(LayerNorm, self).__init__() 11 | if isinstance(normalized_shape, numbers.Integral): 12 | normalized_shape = (normalized_shape,) 13 | self.normalized_shape = tuple(normalized_shape) 14 | self.eps = eps 15 | self.elementwise_affine = elementwise_affine 16 | if self.elementwise_affine: 17 | self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) 18 | self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) 19 | else: 20 | self.register_parameter('weight', None) 21 | self.register_parameter('bias', None) 22 | self.reset_parameters() 23 | 24 | 25 | def reset_parameters(self): 26 | if self.elementwise_affine: 27 | init.ones_(self.weight) 28 | init.zeros_(self.bias) 29 | 30 | def forward(self, input, idx): 31 | if self.elementwise_affine: 32 | return F.layer_norm(input, tuple(input.shape[1:]), self.weight[:,idx,:], self.bias[:,idx,:], self.eps) 33 | else: 34 | return F.layer_norm(input, tuple(input.shape[1:]), self.weight, self.bias, self.eps) 35 | 36 | def extra_repr(self): 37 | return '{normalized_shape}, eps={eps}, ' \ 38 | 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) 39 | -------------------------------------------------------------------------------- /layers/mlp_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, inputs, hiddens, act, out_act): 8 | """ 9 | Just an MLP. 10 | Parameters 11 | ---------- 12 | inputs : int 13 | Input size. 14 | hiddens : list[int] 15 | List of hidden units of each dense layer. 16 | act : callable 17 | Activation function. 18 | out_act : bool 19 | Whether to apply activation on the output of the last dense layer. 
20 | """ 21 | super().__init__() 22 | 23 | self.W = nn.ModuleList() 24 | self.act = act 25 | self.out_act = out_act 26 | for i in range(len(hiddens)): 27 | in_dims = inputs if i == 0 else hiddens[i - 1] 28 | out_dims = hiddens[i] 29 | self.W.append(nn.Linear(in_dims, out_dims)) 30 | 31 | def forward(self, x): 32 | for i, W in enumerate(self.W): 33 | x = W(x) 34 | if i != len(self.W) - 1 or self.out_act: 35 | x = self.act(x) 36 | return x 37 | -------------------------------------------------------------------------------- /layers/smt.py: -------------------------------------------------------------------------------- 1 | """Heterograph NN modules""" 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | class HeteroGraphConv(nn.Module): 6 | #this is the proposed message traverse layer. 7 | def __init__(self, mods, dim, num_rel): 8 | super(HeteroGraphConv, self).__init__() 9 | self.mods = nn.ModuleDict(mods) 10 | # Do not break if graph has 0-in-degree nodes. 11 | # Because there is no general rule to add self-loop for heterograph. 
"""Heterograph NN modules."""
import torch as th
import torch.nn as nn


class HeteroGraphConv(nn.Module):
    """The proposed message-traverse layer.

    Runs one sub-module per relation (edge type) of a DGL heterograph and
    fuses the per-relation outputs with a learned attention (:meth:`agg_fn`).
    """

    def __init__(self, mods, dim, num_rel):
        """
        :param mods: dict mapping edge-type name -> conv module for that relation.
        :param dim: feature dimension of the per-relation outputs.
        :param num_rel: number of relations fused by :meth:`agg_fn`.
        """
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        # Do not break if the graph has 0-in-degree nodes: there is no general
        # rule to add self-loops on a heterograph, so just allow them.
        for _, v in self.mods.items():
            set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None)
            if callable(set_allow_zero_in_degree_fn):
                set_allow_zero_in_degree_fn(True)
        # Key/query/value projections shared across relations.
        self.kn = th.nn.Linear(dim, dim, bias=False)
        self.qn = th.nn.Linear(dim, dim, bias=False)
        self.vn = th.nn.Linear(dim, dim, bias=False)

        self.w1 = nn.Parameter(th.FloatTensor(size=(1, dim)))
        self.w2 = nn.Parameter(th.FloatTensor(size=(1, dim)))
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.w1, gain=gain)
        nn.init.xavier_normal_(self.w2, gain=gain)
        self.num_rel = num_rel

    def agg_fn(self, tensors, dsttype):
        """Attention-fuse the ``num_rel`` per-relation tensors.

        ``tensors[0]`` provides the queries; entries whose additive score is
        exactly 0 are masked to -inf before the softmax so they receive zero
        weight.  Assumes 4-D tensors with the feature dim last (summed over
        ``dim=3``) — confirm the exact layout against callers.
        """
        kv, qv, vv = [], [], []
        for i in range(self.num_rel):
            kv.append(self.kn(tensors[i]))
            qv.append(self.qn(tensors[i]))
            vv.append(self.vn(tensors[i]))
        scores = []
        for i in range(self.num_rel):
            # Additive score between relation i's keys and relation 0's queries.
            p = kv[i] * self.w1 + qv[0] * self.w2
            scores.append(p.sum(dim=3))
        alpha = th.cat(scores, dim=2)
        # Mask exact-zero scores out of the softmax.
        mask = th.zeros_like(alpha) - float("Inf")
        mask = th.where(alpha == 0, mask, th.zeros_like(mask))
        alpha = th.softmax(alpha + mask, dim=2)
        values = th.cat(vv, dim=2).permute(3, 0, 1, 2)
        out = (values * alpha).sum(dim=3)
        return out.unsqueeze(dim=-1).permute(1, 2, 3, 0)

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        """Forward computation.

        Invoke the forward function of each relation's module and aggregate
        the results with :meth:`agg_fn`.

        :param g: DGLHeteroGraph.
        :param inputs: dict[str, Tensor] or pair of dicts (src, dst) of node features.
        :param mod_args: optional dict of extra positional args per edge type.
        :param mod_kwargs: optional dict of extra keyword args per edge type.
        :return: dict[str, Tensor] of output representations per node type.
        """
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty: [] for nty in g.dsttypes}
        if isinstance(inputs, tuple) or g.is_block:
            if isinstance(inputs, tuple):
                src_inputs, dst_inputs = inputs
            else:
                # For a block, destination nodes are a prefix of the sources.
                src_inputs = inputs
                dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
            for stype, etype, dtype in g.canonical_etypes:
                rel_graph = g[stype, etype, dtype]
                if rel_graph.number_of_edges() == 0:
                    continue
                if stype not in src_inputs or dtype not in dst_inputs:
                    continue
                dstdata = self.mods[etype](
                    rel_graph,
                    (src_inputs[stype], dst_inputs[dtype]),
                    *mod_args.get(etype, ()),
                    **mod_kwargs.get(etype, {}))
                outputs[dtype].append(dstdata)
        else:
            for stype, etype, dtype in g.canonical_etypes:
                rel_graph = g[stype, etype, dtype]
                if rel_graph.number_of_edges() == 0:
                    continue
                if stype not in inputs:
                    continue
                dstdata = self.mods[etype](
                    rel_graph,
                    inputs[stype],
                    *mod_args.get(etype, ()),
                    **mod_kwargs.get(etype, {}))
                outputs[dtype].append(dstdata)
        rsts = {}
        for nty, alist in outputs.items():
            if len(alist) != 0:
                rsts[nty] = self.agg_fn(alist, nty)
        return rsts


def get_aggregate_fn(agg):
    """Internal function to get the aggregation function for node data
    generated from different relations.

    :param agg: str, one of 'sum', 'max', 'min', 'mean', 'stack'.
    :return: callable taking (list_of_tensors, dsttype) and returning one
        aggregated tensor (or None on an empty list).
    :raises ValueError: if ``agg`` is not a recognized aggregator.
    """
    if agg == 'sum':
        fn = th.sum
    elif agg == 'max':
        fn = lambda inputs, dim: th.max(inputs, dim=dim)[0]
    elif agg == 'min':
        fn = lambda inputs, dim: th.min(inputs, dim=dim)[0]
    elif agg == 'mean':
        fn = th.mean
    elif agg == 'stack':
        fn = None  # will not be called
    else:
        # BUGFIX: this previously raised DGLError, which is never imported in
        # this module, so the error path crashed with NameError instead.
        raise ValueError('Invalid cross type aggregator. Must be one of '
                         '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
    if agg == 'stack':
        def stack_agg(inputs, dsttype):  # pylint: disable=unused-argument
            if len(inputs) == 0:
                return None
            return th.stack(inputs, dim=1)
        return stack_agg
    else:
        def aggfn(inputs, dsttype):  # pylint: disable=unused-argument
            if len(inputs) == 0:
                return None
            stacked = th.stack(inputs, dim=0)
            return fn(stacked, dim=0)
        return aggfn
# main.py module header: imports and global torch setup.
# NOTE: statement order is preserved exactly — the wildcard imports below
# shadow earlier names, so reordering would change which bindings win.
from __future__ import division
import argparse
from dataset.data import *
from utils.metrics import *
from utils.process import *
import os
from trainer.ctrainer import CTrainer
from trainer.rtrainer import RTrainer
from nets.traverse_net import TraverseNet, TraverseNetst
from nets.stgcn_net import STGCNnet
from nets.graphwavenet import gwnet
from nets.astgcn_net import ASTGCNnet
from nets.dcrnn_net import DCRNNModel
import pickle
import dgl
import json
import random
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
# Cap intra-op parallelism for reproducible, polite CPU usage.
torch.set_num_threads(3)
def str_to_bool(value):
    """Parse a boolean-ish CLI string ('yes'/'no', 't'/'f', '1'/'0', ...).

    Booleans pass through unchanged; unrecognized strings raise ValueError.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in {'false', 'f', '0', 'no', 'n'}:
        return False
    if lowered in {'true', 't', '1', 'yes', 'y'}:
        return True
    raise ValueError(f'{value} is not a valid boolean value')


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the
    optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`): the optimizer to schedule.
        num_warmup_steps (:obj:`int`): number of steps in the warmup phase.
        num_training_steps (:obj:`int`): total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1): index of the last
            epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        # Linear warmup from 0 to 1 over the first num_warmup_steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Then linear decay to 0 at num_training_steps (clamped at 0 after).
        remaining = float(num_training_steps - current_step)
        span = float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, remaining / span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def gpu_setup(use_gpu, gpu_id):
    """Pin the visible CUDA device and return the torch device to use."""
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    if torch.cuda.is_available() and use_gpu:
        print('cuda available with GPU:', torch.cuda.get_device_name(0))
        return torch.device("cuda")
    print('cuda not available')
    return torch.device("cpu")
interleaves temporal graphs with spatial graphs. 96 | file = open(params['graph_path'], "rb") 97 | graph = pickle.load(file) 98 | file.close() 99 | relkeys = graph.keys() 100 | print([t[1] for t in graph.keys()]) 101 | graph = dgl.heterograph(graph) 102 | graph = graph.to(device) 103 | 104 | file1 = open(params['graph_path1'], "rb") 105 | graph1 = pickle.load(file1) 106 | file1.close() 107 | relkeys1 = graph1.keys() 108 | print([t[1] for t in graph1.keys()]) 109 | graph1 = dgl.heterograph(graph1) 110 | graph1 = graph1.to(device) 111 | 112 | model = TraverseNetst(net_params, graph, graph1, relkeys, relkeys1) 113 | optimizer = optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay']) 114 | trainer = CTrainer(model, optimizer, masked_mae, dataloader, params, net_params['seq_out_len'], scaler, device) 115 | 116 | elif net_params['model']=='stgcn': 117 | adj_mx = sym_adj((adj_mx+adj_mx.transpose())/2) 118 | adj_mx = torch.Tensor(adj_mx.todense()).to(device) 119 | 120 | model = STGCNnet(net_params, adj_mx) 121 | optimizer = optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay']) 122 | trainer = CTrainer(model, optimizer, masked_mae, dataloader, params, net_params['seq_out_len'], scaler, device) 123 | 124 | elif net_params['model']=='graphwavenet': 125 | supports = [torch.Tensor(asym_adj(adj_mx).todense()).to(device),torch.Tensor(asym_adj(np.transpose(adj_mx)).todense()).to(device)] 126 | model = gwnet(net_params, device, supports) 127 | optimizer = optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay']) 128 | trainer = CTrainer(model, optimizer, masked_mae, dataloader, params, net_params['seq_out_len'], scaler, device) 129 | 130 | 131 | elif net_params['model']=='astgcn': 132 | L_tilde = scaled_Laplacian(adj_mx) 133 | cheb_polynomials = [torch.from_numpy(i).type(torch.FloatTensor).to(device) for i in 134 | cheb_polynomial(L_tilde, net_params['K'])] 135 | model = 
ASTGCNnet(cheb_polynomials, net_params, device) 136 | optimizer = optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay']) 137 | trainer = CTrainer(model, optimizer, masked_mae, dataloader, params, net_params['seq_out_len'], scaler, device) 138 | 139 | elif net_params['model']=='dcrnn': 140 | model = DCRNNModel(adj_mx, device, net_params) 141 | optimizer = optim.Adam(model.parameters(), lr=params['lr'], eps=params['epsilon']) 142 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30, 40, 50], gamma=0.1) 143 | trainer = RTrainer(model, optimizer, lr_scheduler, masked_mae, dataloader, params, net_params, scaler, device) 144 | 145 | elif net_params['model']=='gru': 146 | #the GRU model is equivalent to a DCRNN model with identity graph adjacency matrix. 147 | adj_mx = np.eye(net_params['num_nodes']) 148 | model = DCRNNModel(adj_mx, device, net_params) 149 | optimizer = optim.Adam(model.parameters(), lr=params['lr'], eps=params['epsilon']) 150 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30, 40, 50], gamma=0.1) 151 | trainer = RTrainer(model, optimizer, lr_scheduler, masked_mae, dataloader, params, net_params, scaler, device) 152 | 153 | else: 154 | print("model is not defined.") 155 | exit 156 | 157 | nParams = sum([p.nelement() for p in model.parameters()]) 158 | print('Number of model parameters is', nParams) 159 | # nParams = sum([p.nelement() for p in model.start_conv.parameters()]) 160 | # print('Number of model parameters for start conv is ', nParams) 161 | # nParams = sum([p.nelement() for p in model.transformer.parameters()]) 162 | # print('Number of model parameters for transformer is ', nParams) 163 | print("start training...",flush=True) 164 | his_loss, train_time, val_time = [], [], [] 165 | 166 | minl = 1e5 167 | 168 | for i in range(params['epochs']): 169 | train_loss,train_mape,train_rmse, traint = trainer.train_epoch() 170 | train_time.append(traint) 171 | 172 
import torch
import torch.nn as nn
import torch.nn.functional as F


class cheb_conv_withSAt(nn.Module):
    '''
    K-order Chebyshev graph convolution, modulated by a spatial attention map.
    '''

    def __init__(self, K, cheb_polynomials, in_channels, out_channels):
        '''
        :param K: int, order of the Chebyshev expansion
        :param cheb_polynomials: list of K (N, N) tensors T_k
        :param in_channels: int, num of channels in the input sequence
        :param out_channels: int, num of channels in the output sequence
        '''
        super(cheb_conv_withSAt, self).__init__()
        self.K = K
        self.cheb_polynomials = cheb_polynomials
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.DEVICE = cheb_polynomials[0].device
        # One (F_in, F_out) weight per Chebyshev order.  NOTE(review): these
        # are allocated via torch.FloatTensor, i.e. uninitialized memory —
        # initialization appears to be the caller's responsibility; confirm.
        self.Theta = nn.ParameterList(
            [nn.Parameter(torch.FloatTensor(in_channels, out_channels).to(self.DEVICE))
             for _ in range(K)])

    def forward(self, x, spatial_attention):
        '''
        Chebyshev graph convolution operation.
        :param x: (batch_size, N, F_in, T)
        :param spatial_attention: (batch_size, N, N) attention scores
        :return: (batch_size, N, F_out, T)
        '''
        batch_size, num_of_vertices, in_channels, num_of_timesteps = x.shape
        per_step = []
        for t in range(num_of_timesteps):
            graph_signal = x[:, :, :, t]  # (b, N, F_in)
            acc = torch.zeros(batch_size, num_of_vertices, self.out_channels).to(self.DEVICE)
            for k in range(self.K):
                # Re-weight T_k entry-wise by the (batched) attention map.
                T_k_with_at = self.cheb_polynomials[k].mul(spatial_attention)
                # Left-multiply by the transposed attention-weighted polynomial.
                rhs = T_k_with_at.permute(0, 2, 1).matmul(graph_signal)
                acc = acc + rhs.matmul(self.Theta[k])
            per_step.append(acc.unsqueeze(-1))  # (b, N, F_out, 1)
        return F.relu(torch.cat(per_step, dim=-1))  # (b, N, F_out, T)


class Temporal_Attention_layer(nn.Module):
    """Computes a (T, T) attention map over the time axis of (B, N, F, T) input."""

    def __init__(self, DEVICE, in_channels, num_of_vertices, num_of_timesteps):
        super(Temporal_Attention_layer, self).__init__()
        # NOTE(review): parameters are allocated uninitialized (torch.FloatTensor);
        # confirm they are initialized elsewhere before training.
        self.U1 = nn.Parameter(torch.FloatTensor(num_of_vertices).to(DEVICE))
        self.U2 = nn.Parameter(torch.FloatTensor(in_channels, num_of_vertices).to(DEVICE))
        self.U3 = nn.Parameter(torch.FloatTensor(in_channels).to(DEVICE))
        self.be = nn.Parameter(torch.FloatTensor(1, num_of_timesteps, num_of_timesteps).to(DEVICE))
        self.Ve = nn.Parameter(torch.FloatTensor(num_of_timesteps, num_of_timesteps).to(DEVICE))

    def forward(self, x):
        '''
        :param x: (batch_size, N, F_in, T)
        :return: (B, T, T)
        '''
        _, num_of_vertices, num_of_features, num_of_timesteps = x.shape
        # (B,N,F,T) -> (B,T,F,N) @ (N,) -> (B,T,F) @ (F,N) -> (B,T,N)
        lhs = torch.matmul(torch.matmul(x.permute(0, 3, 2, 1), self.U1), self.U2)
        rhs = torch.matmul(self.U3, x)  # (F,) @ (B,N,F,T) -> (B,N,T)
        product = torch.matmul(lhs, rhs)  # (B,T,N) @ (B,N,T) -> (B,T,T)
        E = torch.matmul(self.Ve, torch.sigmoid(product + self.be))  # (B,T,T)
        return F.softmax(E, dim=1)


class cheb_conv(nn.Module):
    '''
    K-order Chebyshev graph convolution (no spatial attention).
    '''

    def __init__(self, K, cheb_polynomials, in_channels, out_channels):
        '''
        :param K: int, order of the Chebyshev expansion
        :param cheb_polynomials: list of K (N, N) tensors T_k
        :param in_channels: int, num of channels in the input sequence
        :param out_channels: int, num of channels in the output sequence
        '''
        super(cheb_conv, self).__init__()
        self.K = K
        self.cheb_polynomials = cheb_polynomials
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.DEVICE = cheb_polynomials[0].device
        self.Theta = nn.ParameterList(
            [nn.Parameter(torch.FloatTensor(in_channels, out_channels).to(self.DEVICE))
             for _ in range(K)])

    def forward(self, x):
        '''
        Chebyshev graph convolution operation.
        :param x: (batch_size, N, F_in, T)
        :return: (batch_size, N, F_out, T)
        '''
        batch_size, num_of_vertices, in_channels, num_of_timesteps = x.shape
        per_step = []
        for t in range(num_of_timesteps):
            graph_signal = x[:, :, :, t]  # (b, N, F_in)
            acc = torch.zeros(batch_size, num_of_vertices, self.out_channels).to(self.DEVICE)
            for k in range(self.K):
                T_k = self.cheb_polynomials[k]  # (N, N)
                rhs = graph_signal.permute(0, 2, 1).matmul(T_k).permute(0, 2, 1)
                acc = acc + rhs.matmul(self.Theta[k])
            per_step.append(acc.unsqueeze(-1))
        return F.relu(torch.cat(per_step, dim=-1))
import torch
import torch.nn as nn
import torch.nn.functional as F


class ASTGCN_block(nn.Module):
    """One ASTGCN block: temporal attention -> spatial attention ->
    attention-weighted Chebyshev graph conv -> temporal conv, with a
    1x1-conv residual branch and LayerNorm over the filter dimension.

    Relies on ``Temporal_Attention_layer``, ``Spatial_Attention_layer`` and
    ``cheb_conv_withSAt`` defined earlier in this module.
    """

    def __init__(self, DEVICE, in_channels, K, nb_chev_filter, nb_time_filter, time_strides,
                 cheb_polynomials, num_of_vertices, num_of_timesteps):
        super(ASTGCN_block, self).__init__()
        self.TAt = Temporal_Attention_layer(DEVICE, in_channels, num_of_vertices, num_of_timesteps)
        self.SAt = Spatial_Attention_layer(DEVICE, in_channels, num_of_vertices, num_of_timesteps)
        self.cheb_conv_SAt = cheb_conv_withSAt(K, cheb_polynomials, in_channels, nb_chev_filter)
        # (1, 3) kernel convolves along time only; the stride downsamples time.
        self.time_conv = nn.Conv2d(nb_chev_filter, nb_time_filter, kernel_size=(1, 3),
                                   stride=(1, time_strides), padding=(0, 1))
        self.residual_conv = nn.Conv2d(in_channels, nb_time_filter, kernel_size=(1, 1),
                                       stride=(1, time_strides))
        # LayerNorm normalizes the channel dim, so channels are moved last in forward.
        self.ln = nn.LayerNorm(nb_time_filter)

    def forward(self, x):
        '''
        :param x: (batch_size, N, F_in, T)
        :return: (batch_size, N, nb_time_filter, T) — time downsampled when
            time_strides > 1
        '''
        batch_size, num_of_vertices, num_of_features, num_of_timesteps = x.shape

        # Temporal attention re-weights the time axis before spatial attention.
        temporal_At = self.TAt(x)  # (b, T, T)
        x_TAt = torch.matmul(x.reshape(batch_size, -1, num_of_timesteps), temporal_At) \
            .reshape(batch_size, num_of_vertices, num_of_features, num_of_timesteps)

        # Spatial attention computed on the temporally re-weighted signal.
        spatial_At = self.SAt(x_TAt)

        # Chebyshev graph conv modulated by the spatial attention. (b,N,F,T)
        spatial_gcn = self.cheb_conv_SAt(x, spatial_At)

        # (b,N,F,T) -> (b,F,N,T) so Conv2d's (1,3) kernel runs along time.
        time_conv_output = self.time_conv(spatial_gcn.permute(0, 2, 1, 3))

        # Residual shortcut via a (1,1) conv on the raw input.
        x_residual = self.residual_conv(x.permute(0, 2, 1, 3))

        # (b,F,N,T) -> (b,T,N,F) -ln-> (b,T,N,F) -> (b,N,F,T)
        x_residual = self.ln(F.relu(x_residual + time_conv_output)
                             .permute(0, 3, 2, 1)).permute(0, 2, 3, 1)
        return x_residual
import numpy as np
import torch
import torch.nn as nn

from layers.dcrnn_cell import DCGRUCell


class Seq2SeqAttrs:
    """Shared hyper-parameter mixin for the DCRNN encoder/decoder."""

    def __init__(self, adj_mx, net_params):
        self.adj_mx = adj_mx
        self.max_diffusion_step = net_params['max_diffusion_step']
        self.cl_decay_steps = net_params['cl_decay_steps']
        self.filter_type = net_params['filter_type']
        self.num_nodes = net_params['num_nodes']
        self.num_rnn_layers = net_params['num_layers']
        self.rnn_units = net_params['dim']
        # Hidden state is stored flattened: one vector of num_nodes * rnn_units.
        self.hidden_state_size = self.num_nodes * self.rnn_units


class EncoderModel(nn.Module, Seq2SeqAttrs):
    """DCRNN encoder: a stack of diffusion-convolutional GRU cells."""

    def __init__(self, adj_mx, device, net_params):
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, adj_mx, net_params)
        self.input_dim = net_params['in_dim']
        self.seq_len = net_params['seq_in_len']  # for the encoder
        self.dcgru_layers = nn.ModuleList(
            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes, device,
                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])

    def forward(self, inputs, hidden_state=None):
        """
        Encoder forward pass.

        :param inputs: (batch_size, num_nodes * input_dim)
        :param hidden_state: (num_layers, batch_size, hidden_state_size);
            zeros if not provided
        :return: (output, stacked next hidden states); lower indices mean
            lower layers.  Runs in O(num_layers), so not too slow.
        """
        batch_size, _ = inputs.size()
        if hidden_state is None:
            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
                                       device=inputs.device)
        next_states = []
        output = inputs
        for idx, cell in enumerate(self.dcgru_layers):
            # Each layer consumes the previous layer's output as its input.
            output = cell(output, hidden_state[idx])
            next_states.append(output)
        return output, torch.stack(next_states)


class DecoderModel(nn.Module, Seq2SeqAttrs):
    """DCRNN decoder: stacked DCGRU cells plus an output projection."""

    def __init__(self, adj_mx, device, net_params):
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, adj_mx, net_params)
        self.output_dim = net_params['out_dim']
        self.horizon = net_params['seq_out_len']  # for the decoder
        self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
        self.dcgru_layers = nn.ModuleList(
            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes, device,
                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])

    def forward(self, inputs, hidden_state=None):
        """
        Decoder forward pass.

        :param inputs: (batch_size, num_nodes * output_dim)
        :param hidden_state: (num_layers, batch_size, hidden_state_size)
        :return: (projected output of shape (batch_size, num_nodes * output_dim),
            stacked next hidden states)
        """
        next_states = []
        output = inputs
        for idx, cell in enumerate(self.dcgru_layers):
            output = cell(output, hidden_state[idx])
            next_states.append(output)

        # Project per-node hidden units down to the output dimension.
        projected = self.projection_layer(output.view(-1, self.rnn_units))
        output = projected.view(-1, self.num_nodes * self.output_dim)

        return output, torch.stack(next_states)
20 | """ 21 | super(TimeBlock, self).__init__() 22 | self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size)) 23 | self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size)) 24 | self.conv3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size)) 25 | 26 | def forward(self, X): 27 | """ 28 | :param X: Input data of shape (batch_size, num_nodes, num_timesteps, 29 | num_features=in_channels) 30 | :return: Output data of shape (batch_size, num_nodes, 31 | num_timesteps_out, num_features_out=out_channels) 32 | """ 33 | # Convert into NCHW format for pytorch to perform convolutions. 34 | #X = X.permute(0, 3, 1, 2) 35 | temp = self.conv1(X) * torch.sigmoid(self.conv2(X)) 36 | out = F.relu(temp + self.conv3(X)) 37 | # Convert back from NCHW to NHWC 38 | #out = out.permute(0, 2, 3, 1) 39 | return out 40 | 41 | 42 | class STGCNBlock(nn.Module): 43 | """ 44 | Neural network block that applies a temporal convolution on each node in 45 | isolation, followed by a graph convolution, followed by another temporal 46 | convolution on each node. 47 | """ 48 | 49 | def __init__(self, in_channels, spatial_channels, out_channels, 50 | num_nodes): 51 | """ 52 | :param in_channels: Number of input features at each node in each time 53 | step. 54 | :param spatial_channels: Number of output channels of the graph 55 | convolutional, spatial sub-block. 56 | :param out_channels: Desired number of output features at each node in 57 | each time step. 58 | :param num_nodes: Number of nodes in the graph. 
59 | """ 60 | super(STGCNBlock, self).__init__() 61 | self.temporal1 = TimeBlock(in_channels=in_channels, 62 | out_channels=out_channels) 63 | self.lint = nn.Conv2d(in_channels=out_channels, out_channels=spatial_channels, kernel_size=(1, 1), bias=False) 64 | self.temporal2 = TimeBlock(in_channels=spatial_channels, 65 | out_channels=out_channels) 66 | self.batch_norm = nn.BatchNorm2d(num_nodes) 67 | 68 | def forward(self, X, A_hat): 69 | """ 70 | :param X: Input data of shape (batch_size, num_features, num_nodes, num_timesteps, 71 | ). 72 | :param A_hat: Normalized adjacency matrix. 73 | :return: Output data of shape (batch_size, num_nodes, 74 | num_timesteps_out, num_features=out_channels). 75 | """ 76 | t = self.temporal1(X) 77 | lfs = torch.einsum("ij,bkjm->bkim", [A_hat, t]) 78 | # t2 = F.relu(torch.einsum("ijkl,lp->ijkp", [lfs, self.Theta1])) 79 | t2 = F.relu(self.lint(lfs)) 80 | t3 = self.temporal2(t2) 81 | t3 = t3.transpose(2,1) 82 | t3 = self.batch_norm(t3) 83 | t3 = t3.transpose(2,1) 84 | return t3 85 | -------------------------------------------------------------------------------- /module/traversebody.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import dgl.nn.pytorch as dglnn 3 | import torch 4 | import copy 5 | from torch import nn 6 | from layers.gat_layer import GATConvs 7 | from dgl.ops import edge_softmax 8 | from layers.smt import HeteroGraphConv 9 | def module_list(module, n): 10 | return nn.ModuleList([copy.deepcopy(module) for i in range(n)]) 11 | 12 | 13 | class Norm(nn.Module): 14 | def __init__(self, dim, eps=1e-6): 15 | super().__init__() 16 | 17 | self.size = dim 18 | # create two learnable parameters to calibrate normalisation 19 | self.alpha = nn.Parameter(torch.ones(self.size)) 20 | self.bias = nn.Parameter(torch.zeros(self.size)) 21 | self.eps = eps 22 | 23 | def forward(self, x): 24 | norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ 25 | / 
import copy

import torch
import torch.nn.functional as F
from torch import nn


def module_list(module, n):
    """Return a ModuleList of ``n`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


class Norm(nn.Module):
    """Layer normalization over the last dim with learnable scale and shift."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.size = dim
        # Two learnable parameters to calibrate the normalization.
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        return self.alpha * centered / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias


class Norm1(nn.Module):
    """BatchNorm2d applied by folding the leading axis into (N, dim) blocks."""

    def __init__(self, dim):
        super().__init__()
        self.norm = torch.nn.BatchNorm2d((dim))
        self.dim = dim

    def forward(self, x):
        sh = x.shape
        folded = x.view(-1, self.dim, sh[1], sh[2])
        return self.norm(folded).view(-1, sh[1], sh[2])


def my_agg_func(tensors, dsttype):
    """Aggregate by keeping only the first relation's tensor."""
    return tensors[0]


def my_agg_func1(tensors, dsttype):
    """Aggregate by keeping only the last relation's tensor."""
    return tensors[-1]


class EncoderLayer(nn.Module):
    """Pre-norm residual layer: hetero-graph attention, then feed-forward.

    Relies on ``GATConvs`` and ``HeteroGraphConv`` imported at the top of
    this module, and on ``FeedForward`` defined below.
    """

    def __init__(self, num_nodes, dim, heads, rel_keys, dropout, layer_id, bi=True):
        super().__init__()
        self.norm_1 = Norm1(num_nodes)
        self.norm_2 = Norm1(num_nodes)
        block = dict()
        if bi:
            # Two shared attention modules, one per traversal direction.
            self.attn1 = GATConvs(dim, dim // heads, heads, num_nodes, layer_id)
            self.attn2 = GATConvs(dim, dim // heads, heads, num_nodes, layer_id)
            for it in rel_keys:
                k = it[1]
                s = k.split('_')
                if len(s) != 1:
                    # Names like "x_-1" encode the direction in the suffix.
                    block[k] = self.attn1 if s[1] == '-1' else self.attn2
                else:
                    block[k] = self.attn1 if s[0] == '0' else self.attn2
        else:
            # Unidirectional variant: one fresh attention module per relation.
            for it in rel_keys:
                block[it[1]] = GATConvs(dim, dim // heads, heads, num_nodes)

        self.conv = HeteroGraphConv(block, dim, len(rel_keys))
        self.ff = FeedForward(dim)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, g, x):
        # x: dict with node features under 'v' (NT * batch_size * dim).
        h = {'v': x['v']}
        h['v'] = self.norm_1(h['v'])
        h = self.conv(g, h)
        shape = h['v'].shape
        h['v'] = h['v'].reshape(*shape[:-2], -1)
        # NOTE: the residual is written back into x['v'] — the caller's dict
        # is intentionally mutated, matching how Encoder chains layers.
        x['v'] = x['v'] + self.dropout_1(h['v'])
        h['v'] = self.norm_2(x['v'])
        h['v'] = x['v'] + self.dropout_2(self.ff(h['v']))
        return h


class Encoder(nn.Module):
    """Stack of ``num_layer`` EncoderLayers sharing one relation set."""

    def __init__(self, num_nodes, dim, heads, relkeys, num_layer, dropout):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_layer = num_layer
        self.layers = nn.ModuleList(
            [EncoderLayer(num_nodes, dim, heads, relkeys, dropout, i) for i in range(num_layer)])
        self.norm = Norm1(num_nodes)
        self.dim = dim

    def forward(self, g, x):
        for layer in self.layers:
            x = layer(g, x)
        x['v'] = self.norm(x['v'])
        return x


class Encoder1(nn.Module):
    """Ablation encoder interleaving two graphs (e.g. temporal / spatial)."""

    def __init__(self, num_nodes, dim, heads, relkeys1, relkeys2, num_layer, dropout):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_layer = num_layer
        self.layers1 = module_list(EncoderLayer(num_nodes, dim, heads, relkeys1, dropout, 2), num_layer)
        self.layers2 = module_list(EncoderLayer(num_nodes, dim, heads, relkeys2, dropout, 2), num_layer)
        # NOTE(review): convl is created but never used in forward — kept for
        # checkpoint/state-dict compatibility.
        self.convl = module_list(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(1, 3)), num_layer)
        self.norm = Norm1(num_nodes)
        self.dim = dim

    def forward(self, g1, g2, x):
        for i in range(self.num_layer):
            x = self.layers1[i](g1, x)
            x = self.layers2[i](g2, x)
        x['v'] = self.norm(x['v'])
        return x
class ASTGCNnet(nn.Module):
    """ASTGCN wrapper: a chain of spatial-temporal attention blocks followed
    by a 2-D convolution that maps the remaining time steps to the horizon.

    forward() input: (B, F_in, N_nodes, T_in); it is transposed internally
    to the (B, N_nodes, F_in, T_in) layout the blocks expect.
    """

    def __init__(self, supports, net_params, DEVICE):
        super().__init__()

        in_dim = net_params['in_dim']
        K = net_params['K']
        nb_chev_filter = net_params['nb_chev_filter']
        nb_time_filter = net_params['nb_time_filter']
        time_strides = net_params['time_strides']
        seq_out_len = net_params['seq_out_len']
        seq_in_len = net_params['seq_in_len']
        num_nodes = net_params['num_nodes']
        nb_block = net_params['nb_block']

        # First block consumes raw features at the full input length; the
        # remaining blocks run on nb_time_filter channels with stride 1 over
        # the already-downsampled sequence.
        blocks = [ASTGCN_block(DEVICE, in_dim, K, nb_chev_filter, nb_time_filter,
                               time_strides, supports, num_nodes, seq_in_len)]
        blocks += [ASTGCN_block(DEVICE, nb_time_filter, K, nb_chev_filter,
                                nb_time_filter, 1, supports, num_nodes,
                                seq_in_len // time_strides)
                   for _ in range(nb_block - 1)]
        self.BlockList = nn.ModuleList(blocks)

        self.final_conv = nn.Conv2d(int(seq_in_len / time_strides), seq_out_len,
                                    kernel_size=(1, nb_time_filter))
        self.init_pars()

    def init_pars(self):
        """Xavier-initialize weight matrices; uniform-initialize 1-D params."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.uniform_(p)

    def forward(self, x, dummy=None):
        """Run all blocks, then collapse the feature axis with final_conv.

        :param x: (B, F_in, N_nodes, T_in)
        :return: prediction tensor over seq_out_len steps
        """
        x = x.transpose(2, 1)
        for blk in self.BlockList:
            x = blk(x)
        # (B, N, F, T) -> (B, T, N, F); the (1, F) kernel consumes features.
        return self.final_conv(x.permute(0, 3, 1, 2))
return encoder_hidden_state 31 | 32 | def decoder(self, encoder_hidden_state, labels=None, batches_seen=None): 33 | """ 34 | Decoder forward pass 35 | :param encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size) 36 | :param labels: (self.horizon, batch_size, self.num_nodes * self.output_dim) [optional, not exist for inference] 37 | :param batches_seen: global step [optional, not exist for inference] 38 | :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) 39 | """ 40 | batch_size = encoder_hidden_state.size(1) 41 | go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim), 42 | device=encoder_hidden_state.device) 43 | decoder_hidden_state = encoder_hidden_state 44 | decoder_input = go_symbol 45 | 46 | outputs = [] 47 | 48 | for t in range(self.decoder_model.horizon): 49 | decoder_output, decoder_hidden_state = self.decoder_model(decoder_input, 50 | decoder_hidden_state) 51 | decoder_input = decoder_output 52 | outputs.append(decoder_output) 53 | if self.training and self.use_curriculum_learning: 54 | c = np.random.uniform(0, 1) 55 | if c < self._compute_sampling_threshold(batches_seen): 56 | decoder_input = labels[t] 57 | outputs = torch.stack(outputs) 58 | return outputs 59 | 60 | def forward(self, inputs, labels=None, batches_seen=None): 61 | """ 62 | seq2seq forward pass 63 | :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 64 | :param labels: shape (horizon, batch_size, num_sensor * output) 65 | :param batches_seen: batches seen till now 66 | :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) 67 | """ 68 | labels = labels[...,0].view(self.net_params['seq_out_len'],-1, self.net_params['num_nodes']) 69 | encoder_hidden_state = self.encoder(inputs) 70 | outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen) 71 | if batches_seen == 0: 72 | print( 73 | "Total trainable parameters {}".format(count_parameters(self)) 74 | ) 75 | 
class nconv(nn.Module):
    """Propagate node features along an adjacency: x (N,C,V,L) with A (V,W)."""

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        out = torch.einsum('ncvl,vw->ncwl', (x, A))
        return out.contiguous()


class linear(nn.Module):
    """1x1 convolution used as a per-node linear projection."""

    def __init__(self, c_in, c_out):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1),
                                   padding=(0, 0), stride=(1, 1), bias=True)

    def forward(self, x):
        return self.mlp(x)


class gcn(nn.Module):
    """Diffusion graph convolution.

    Concatenates x with its 1..order hop propagations over every support
    matrix, then projects the stack back to c_out channels with a 1x1 conv.
    """

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super(gcn, self).__init__()
        self.nconv = nconv()
        # One copy of x plus `order` hops per support matrix.
        total_in = (order * support_len + 1) * c_in
        self.mlp = linear(total_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        hops = [x]
        for adj in support:
            prev = self.nconv(x, adj)
            hops.append(prev)
            for _ in range(2, self.order + 1):
                nxt = self.nconv(prev, adj)
                hops.append(nxt)
                prev = nxt
        h = self.mlp(torch.cat(hops, dim=1))
        return F.dropout(h, self.dropout, training=self.training)
class gwnet(nn.Module):
    """Graph WaveNet: stacked gated dilated temporal convolutions with
    optional (adaptive) graph convolutions and skip connections.

    NOTE(review): the text of forward() was corrupted in this copy of the
    file; the padding / adaptive-adjacency prologue below is reconstructed
    from the visible tail and the reference Graph WaveNet architecture —
    verify against the original source.
    """

    def __init__(self, net_params, device, supports):
        super(gwnet, self).__init__()
        num_nodes = net_params['num_nodes']
        dropout = net_params['dropout']
        gcn_bool = net_params['gcn_bool']
        addaptadj = net_params['addaptadj']
        in_dim = net_params['in_dim']
        out_dim = net_params['seq_out_len']
        residual_channels = net_params['residual_channels']
        dilation_channels = net_params['dilation_channels']
        skip_channels = net_params['skip_channels']
        end_channels = net_params['end_channels']
        kernel_size = net_params['kernel_size']
        blocks = net_params['blocks']
        layers = net_params['num_layers']

        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers
        self.gcn_bool = gcn_bool
        self.addaptadj = addaptadj

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gconv = nn.ModuleList()

        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=residual_channels,
                                    kernel_size=(1, 1))
        self.supports = supports

        receptive_field = 1

        self.supports_len = 0
        if supports is not None:
            self.supports_len += len(supports)

        if gcn_bool and addaptadj:
            # Learnable node embeddings used to build the adaptive adjacency.
            self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10).to(device), requires_grad=True).to(device)
            self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes).to(device), requires_grad=True).to(device)
            self.supports_len += 1

        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # Dilated convolutions of the gated activation unit.
                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=(1, kernel_size),
                                                   dilation=new_dilation))
                # FIX: gate/residual/skip convs were declared as nn.Conv1d
                # with 2-D kernel sizes on 4-D inputs — that only worked by
                # accident on old PyTorch. nn.Conv2d is the mathematically
                # identical, portable declaration.
                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=(1, kernel_size),
                                                 dilation=new_dilation))
                # 1x1 convolution for the residual connection.
                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=(1, 1)))
                # 1x1 convolution for the skip connection.
                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=(1, 1)))
                self.bn.append(nn.BatchNorm2d(residual_channels))
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(gcn(dilation_channels, residual_channels,
                                          dropout, support_len=self.supports_len))

        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
                                    out_channels=end_channels,
                                    kernel_size=(1, 1),
                                    bias=True)

        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
                                    out_channels=out_dim,
                                    kernel_size=(1, 1),
                                    bias=True)

        self.receptive_field = receptive_field

    def forward(self, input, dummy=None):
        """input: (B, in_dim, N, T) -> (B, seq_out_len, N, T')."""
        in_len = input.size(3)
        # Left-pad the time axis so every dilated conv sees enough history.
        if in_len < self.receptive_field:
            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input
        x = self.start_conv(x)
        skip = None

        # Adaptive adjacency matrix built from the learned node embeddings.
        new_supports = None
        if self.gcn_bool and self.addaptadj and self.supports is not None:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            new_supports = self.supports + [adp]

        for i in range(self.blocks * self.layers):
            #            |----------------------------------------|     *residual*
            #            |                                        |
            #            |    |-- conv -- tanh --|                |
            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
            #                 |-- conv -- sigm --|     |
            #                                         1x1
            #                                          |
            # ---------------------------------------> + ------------>  *skip*
            residual = x
            # Gated activation unit.
            filter = torch.tanh(self.filter_convs[i](residual))
            gate = torch.sigmoid(self.gate_convs[i](residual))
            x = filter * gate

            # Parametrized skip connection; crop the accumulated skip to the
            # (shorter) current time length before adding.  (Replaces the
            # original bare try/except used to initialise `skip`.)
            s = self.skip_convs[i](x)
            if skip is None:
                skip = s
            else:
                skip = s + skip[:, :, :, -s.size(3):]

            if self.gcn_bool and self.supports is not None:
                if self.addaptadj:
                    x = self.gconv[i](x, new_supports)
                else:
                    x = self.gconv[i](x, self.supports)
            else:
                x = self.residual_convs[i](x)

            # Residual connection, cropped to the post-conv time length.
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x
class STGCNnet(nn.Module):
    """
    Spatio-temporal graph convolutional network as described in
    https://arxiv.org/abs/1709.04875v3 by Yu et al.
    Input should have shape (batch_size, num_nodes, num_input_time_steps,
    num_features).
    """

    def __init__(self, net_params, adj):
        """
        :param net_params: dict with 'num_nodes', 'in_dim', 'seq_in_len',
            'seq_out_len'
        :param adj: normalized adjacency matrix shared by all blocks
        """
        super(STGCNnet, self).__init__()

        num_nodes = net_params['num_nodes']
        num_features = net_params['in_dim']
        num_timesteps_input = net_params['seq_in_len']
        num_timesteps_output = net_params['seq_out_len']

        self.block1 = STGCNBlock(in_channels=num_features, out_channels=64,
                                 spatial_channels=32, num_nodes=num_nodes)
        self.block2 = STGCNBlock(in_channels=64, out_channels=128,
                                 spatial_channels=32, num_nodes=num_nodes)
        self.last_temporal = TimeBlock(in_channels=128, out_channels=64)
        # Each temporal convolution trims 2 steps; 5 of them run in total,
        # hence (T_in - 2*5) steps remain before the linear read-out.
        self.fully = nn.Linear((num_timesteps_input - 2 * 5) * 64,
                               num_timesteps_output)
        self.adj = adj

    def forward(self, X, dummy):
        """
        :param X: (batch_size, num_nodes, num_timesteps, num_features)
        :return: prediction tensor after the axis shuffling below
        """
        h = self.block1(X, self.adj)
        h = self.block2(h, self.adj)
        h = self.last_temporal(h)
        h = h.transpose(2, 1)
        out = self.fully(h.reshape((h.shape[0], h.shape[1], -1)))
        out = out.unsqueeze(dim=1)
        return out.transpose(3, 1)


class Norm1(nn.Module):
    """BatchNorm2d over a transposed, folded view of the input.

    NOTE(review): .view() directly after .transpose() requires a contiguous
    tensor and will raise for general shapes; this class appears unused by
    the networks below — confirm before relying on it.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = torch.nn.BatchNorm2d(dim)
        self.dim = dim

    def forward(self, x):
        sh = x.shape
        swapped = x.transpose(2, 1)
        folded = swapped.view(-1, self.dim, sh[1], sh[2])
        return self.norm(folded).transpose(1, 2)


class PrePreplayer(nn.Module):
    """Input stem: 1x1 conv to the model width followed by LayerNorm."""

    def __init__(self, in_dim, dim, num_nodes, seq_l, dropout):
        super().__init__()
        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=dim,
                                    kernel_size=(1, 1))
        self.norm1 = torch.nn.LayerNorm((dim, num_nodes, seq_l))
        self.dropout = dropout

    def forward(self, x, dummy):
        return self.norm1(self.start_conv(x))


class PostPreplayer(nn.Module):
    """Output head: LayerNorm, then two convolutions that collapse the time
    axis and map to the prediction horizon."""

    def __init__(self, dim, out_dim, num_nodes, seq_l, dropout):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm((dim, num_nodes, seq_l))
        self.end_conv_1 = nn.Conv2d(in_channels=dim, out_channels=out_dim ** 2,
                                    kernel_size=(1, seq_l))
        self.end_conv_2 = nn.Conv2d(in_channels=out_dim ** 2, out_channels=out_dim,
                                    kernel_size=(1, 1))
        self.dim = dim
        self.seq_l = seq_l
        self.num_nodes = num_nodes
        self.dropout = dropout

    def forward(self, x):
        h = self.norm1(x)
        h = F.relu(self.end_conv_1(h))
        return self.end_conv_2(h)
class TraverseNet(nn.Module):
    """TraverseNet: conv stem -> traversal-attention encoder over graph g1
    -> conv head that maps the input window to the prediction horizon."""

    def __init__(self, net_params, g1, relkeys):
        super().__init__()
        self.start_conv = PrePreplayer(net_params['in_dim'], net_params['dim'],
                                       net_params['num_nodes'],
                                       net_params['seq_in_len'],
                                       net_params['dropout'])
        self.end_conv = PostPreplayer(net_params['dim'], net_params['seq_out_len'],
                                      net_params['num_nodes'],
                                      net_params['seq_in_len'],
                                      net_params['dropout'])
        self.transformer = Encoder(net_params['num_nodes'], net_params['dim'],
                                   net_params['heads'], relkeys,
                                   net_params['num_layers'], net_params['dropout'])

        self.in_dim = net_params['in_dim']
        self.num_nodes = net_params['num_nodes']
        self.seq_in_len = net_params['seq_in_len']
        self.seq_out_len = net_params['seq_out_len']
        self.dim = net_params['dim']
        self.g1 = g1
        self.cl_decay_steps = net_params['cl_decay_steps']
        self.num_layer = net_params['num_layers']

    def _init_pos(self, sq, dim):
        """Sinusoidal positional-encoding table of shape (sq, dim).

        NOTE(review): not referenced by forward(); appears to be unused.
        """
        enc = torch.Tensor(sq, dim)
        for pos in range(sq):
            for i in range(0, dim, 2):
                enc[pos, i] = math.sin(pos / (10000 ** ((2 * i) / dim)))
                enc[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / dim)))
        return enc

    def _compute_sampling_threshold(self, batches_seen):
        """Inverse-sigmoid scheduled-sampling probability."""
        return self.cl_decay_steps / (
            self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

    def forward(self, src, dummy):
        # Keep only the last seq_in_len steps of the first in_dim features.
        window = src[:, :self.in_dim, :, -self.seq_in_len:]
        h = self.start_conv(window, dummy)
        # (B, D, N, T) -> (T, N, B, D) -> (N*T, B, D) node-time tokens.
        h = h.permute(3, 2, 0, 1)
        h = h.reshape(self.num_nodes * self.seq_in_len, -1, self.dim)
        out = self.transformer(self.g1, {'v': h})
        out = out['v'].reshape(self.seq_in_len, self.num_nodes, -1, self.dim)
        out = out.permute(2, 3, 1, 0)
        return self.end_conv(out)


# interleave spatial attentions with temporal attentions
class TraverseNetst(nn.Module):
    """TraverseNet variant that alternates attention over two graphs
    (g1 then g2) at every encoder depth, via Encoder1."""

    def __init__(self, net_params, g1, g2, relkeys1, relkeys2):
        super().__init__()
        self.start_conv = PrePreplayer(net_params['in_dim'], net_params['dim'],
                                       net_params['num_nodes'],
                                       net_params['seq_in_len'],
                                       net_params['dropout'])
        self.end_conv = PostPreplayer(net_params['dim'], net_params['seq_out_len'],
                                      net_params['num_nodes'],
                                      net_params['seq_in_len'],
                                      net_params['dropout'])
        self.transformer = Encoder1(net_params['num_nodes'], net_params['dim'],
                                    net_params['heads'], relkeys1, relkeys2,
                                    net_params['num_layers'], net_params['dropout'])

        self.in_dim = net_params['in_dim']
        self.num_nodes = net_params['num_nodes']
        self.seq_in_len = net_params['seq_in_len']
        self.seq_out_len = net_params['seq_out_len']
        self.dim = net_params['dim']
        self.g1 = g1
        self.g2 = g2
        self.cl_decay_steps = net_params['cl_decay_steps']
        self.num_layer = net_params['num_layers']

    def _init_pos(self, sq, dim):
        """Sinusoidal positional-encoding table of shape (sq, dim).

        NOTE(review): not referenced by forward(); appears to be unused.
        """
        enc = torch.Tensor(sq, dim)
        for pos in range(sq):
            for i in range(0, dim, 2):
                enc[pos, i] = math.sin(pos / (10000 ** ((2 * i) / dim)))
                enc[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / dim)))
        return enc

    def _compute_sampling_threshold(self, batches_seen):
        """Inverse-sigmoid scheduled-sampling probability."""
        return self.cl_decay_steps / (
            self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

    def forward(self, src, dummy):
        window = src[:, :self.in_dim, :, -self.seq_in_len:]
        h = self.start_conv(window, dummy)
        h = h.permute(3, 2, 0, 1)
        h = h.reshape(self.num_nodes * self.seq_in_len, -1, self.dim)
        out = self.transformer(self.g1, self.g2, {'v': h})
        out = out['v'].reshape(self.seq_in_len, self.num_nodes, -1, self.dim)
        out = out.permute(2, 3, 1, 0)
        return self.end_conv(out)
-------------------------------------------------------------------------------- 1 | from dataset.data import * 2 | 3 | batch_size = 64 4 | num_nodes = 170 5 | path='./data/PEMS08/PEMS08.npz' 6 | adjpath='./data/PEMS08/PEMS08.csv' 7 | idpath='./data/PEMS08/PEMS08.txt' 8 | op = PemsData(num_nodes, path, adjpath, None) 9 | op.prcoess('./data/PEMS-D8.pkl') 10 | _, srclist, tgtlist, distlist = op.load_graph() 11 | g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 12 | file = open('./data/PEMS-D8-Gt.pkl', "wb") 13 | pickle.dump(g, file) 14 | # 15 | # 16 | batch_size = 64 17 | num_nodes = 358 18 | path='./data/PEMS03/PEMS03.npz' 19 | adjpath='./data/PEMS03/PEMS03.csv' 20 | idpath='./data/PEMS03/PEMS03.txt' 21 | op = PemsData(num_nodes, path, adjpath, idpath) 22 | op.prcoess('./data/PEMS-D3.pkl') 23 | _, srclist, tgtlist, distlist = op.load_graph() 24 | g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 25 | file = open('./data/PEMS-D3-Gt.pkl', "wb") 26 | pickle.dump(g, file) 27 | # 28 | batch_size = 64 29 | num_nodes = 307 30 | path='./data/PEMS04/PEMS04.npz' 31 | adjpath='./data/PEMS04/PEMS04.csv' 32 | idpath='./data/PEMS04/PEMS04.txt' 33 | op = PemsData(num_nodes, path, adjpath, None) 34 | op.prcoess('./data/PEMS-D4.pkl') 35 | _, srclist, tgtlist, distlist = op.load_graph() 36 | g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 37 | file = open('./data/PEMS-D4-Gt.pkl', "wb") 38 | pickle.dump(g, file) 39 | # 40 | -------------------------------------------------------------------------------- /proc_new_data.py: -------------------------------------------------------------------------------- 1 | from dataset.data import * 2 | 3 | # num_nodes = 20 4 | # op = PoxData("./dataset1/chickenpox.json",'X') 5 | # op.prcoess('./pox.pkl') 6 | # _, srclist, tgtlist, distlist = op.load_graph1() 7 | # g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 8 | # file = 
open('./pox-Gt.pkl', "wb") 9 | # pickle.dump(g, file) 10 | # 11 | # num_nodes = 319 12 | # op = WindmillData("./dataset1/windmill_output.json") 13 | # op.prcoess('./windmill.pkl') 14 | # _, srclist, tgtlist, distlist = op.load_graph() 15 | # g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 16 | # file = open('./windmill-Gt.pkl', "wb") 17 | # pickle.dump(g, file) 18 | 19 | 20 | #num_nodes = 321 21 | #op = MulData("./data/electricity/electricity.txt") 22 | #_, srclist, tgtlist, distlist = op.prcoess('./data/electricity.pkl') 23 | #g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 24 | #file = open('./data/electricity-Gt.pkl', "wb") 25 | #pickle.dump(g, file) 26 | 27 | num_nodes = 137 28 | op = MulData("./data/solar-energy/solar_AL.txt") 29 | _, srclist, tgtlist, distlist = op.prcoess('./data/solar.pkl') 30 | g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 31 | file = open('./data/solar-Gt.pkl', "wb") 32 | pickle.dump(g, file) 33 | 34 | #num_nodes = 8 35 | #op = MulData("./data/exchange_rate/exchange_rate.txt") 36 | #_, srclist, tgtlist, distlist = op.prcoess('./data/exchange.pkl') 37 | #g, _ = process_t_graph(srclist, tgtlist, distlist, 12, num_nodes, window=12) 38 | #file = open('./data/exchange-Gt.pkl', "wb") 39 | #pickle.dump(g, file) 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.1 2 | pandas==1.0.1 3 | scipy==1.4.1 4 | dgl==0.5.2 5 | tqdm==4.42.1 6 | torch==1.6.0 7 | -------------------------------------------------------------------------------- /train_electricity.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python main.py --config ./config_electricity/astgcn.json | tee ./log/electricity_astgcn.log 4 | python main.py --config ./config_electricity/dcrnn.json | tee 
./log/electricity_dcrnn.log 5 | python main.py --config ./config_electricity/graphwavenet.json | tee ./log/electricity_graphwavenet.log 6 | python main.py --config ./config_electricity/gru.json | tee ./log/electricity_gru.log 7 | python main.py --config ./config_electricity/stgcn.json | tee ./log/electricity_stgcn.log 8 | python main.py --config ./config_electricity/traversenet.json | tee ./log/electricity_traversenet.log 9 | -------------------------------------------------------------------------------- /train_solar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python main.py --config ./config_solar/graphwavenet.json | tee ./log/solar_graphwavenet.log 3 | python main.py --config ./config_solar/astgcn.json | tee ./log/solar_astgcn.log 4 | python main.py --config ./config_solar/dcrnn.json | tee ./log/solar_dcrnn.log 5 | #python main.py --config ./config_solar/gru.json | tee ./log/solar_gru.log 6 | #python main.py --config ./config_solar/stgcn.json | tee ./log/solar_stgcn.log 7 | #python main.py --config ./config_solar/traversenet.json | tee ./log/solar_traversenet.log 8 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnzhan/TraverseNet/ba4ce7478386cb478293f5283a94c40bacdec0cc/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/ctrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import metrics 3 | import numpy as np 4 | import time 5 | import pickle 6 | import dgl 7 | from utils.process import * 8 | class CTrainer: 9 | # a trainer for models that make predictions once for all future steps. 
class CTrainer:
    """Trainer for models that predict all future steps in one forward pass.

    Expects: a dataloader dict with '{train,val,test}_loader' iterators
    yielding (x, y) numpy batches; a scaler exposing
    inverse_transform(tensor, level); and a masked loss with signature
    loss(predict, real, null_val).
    """

    def __init__(self, model, optimizer, loss, dataloader, params, seq_out_len, scaler, device):
        self.model = model
        self.model.to(device)
        self.dataloader = dataloader
        self.scaler = scaler
        self.device = device

        self.optimizer = optimizer
        self.loss = loss
        self.regloss = torch.nn.BCELoss()

        self.clip = params['clip']
        self.print_every = params['print_every']
        self.seq_out_len = seq_out_len
        self.params = params

    def train_epoch(self):
        """Run one training epoch; return (mean_loss, mean_mape, mean_rmse, secs)."""
        train_loss = []
        train_mape = []
        train_rmse = []
        t1 = time.time()

        self.dataloader['train_loader'].shuffle()
        for iter, (x, y) in enumerate(self.dataloader['train_loader'].get_iterator()):
            trainx = torch.Tensor(x).to(self.device)
            trainx = trainx.transpose(1, 3)

            trainy = torch.Tensor(y).to(self.device)
            trainy = trainy.transpose(1, 3)[:, :, :, :self.seq_out_len]

            metrics = self.train(trainx, trainy[:, self.params['out_level'], :, :])

            train_loss.append(metrics[0])
            train_mape.append(metrics[1])
            train_rmse.append(metrics[2])
            if iter % self.print_every == 0:
                log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}'
                print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]), flush=True)
        t2 = time.time()
        return np.mean(train_loss), np.mean(train_mape), np.mean(train_rmse), t2 - t1

    def val_epoch(self):
        """Run one validation epoch; return (mean_loss, mean_mape, mean_rmse, secs)."""
        valid_loss = []
        valid_mape = []
        valid_rmse = []

        t1 = time.time()
        for iter, (x, y) in enumerate(self.dataloader['val_loader'].get_iterator()):
            testx = torch.Tensor(x).to(self.device)
            testx = testx.transpose(1, 3)
            testy = torch.Tensor(y).to(self.device)
            testy = testy.transpose(1, 3)[:, :, :, :self.seq_out_len]
            with torch.no_grad():
                metrics = self.eval(testx, testy[:, self.params['out_level'], :, :])

            valid_loss.append(metrics[0])
            valid_mape.append(metrics[1])
            valid_rmse.append(metrics[2])
        t2 = time.time()

        return np.mean(valid_loss), np.mean(valid_mape), np.mean(valid_rmse), t2 - t1

    def train(self, x, real_val):
        """One optimization step; return (loss, mape, rmse) as floats."""
        self.model.train()
        self.optimizer.zero_grad()
        # Dummy grad-requiring tensor passed through model forwards that use
        # gradient checkpointing.
        dummy = torch.zeros(10).requires_grad_()
        output = self.model(x, dummy)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        real = self.scaler.inverse_transform(real, self.params['out_level'])
        predict = self.scaler.inverse_transform(output, self.params['out_level'])
        loss = self.loss(predict, real, 0.0)
        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

        self.optimizer.step()
        mape = metrics.masked_mape(predict, real).item()
        rmse = metrics.masked_rmse(predict, real).item()
        return loss.item(), mape, rmse

    def eval(self, x, real_val):
        """One evaluation step; return (loss, mape, rmse) as floats."""
        self.model.eval()
        dummy = torch.zeros(10).requires_grad_()
        output = self.model(x, dummy)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        real = self.scaler.inverse_transform(real, self.params['out_level'])
        predict = self.scaler.inverse_transform(output, self.params['out_level'])
        # FIX: train() masks null value 0.0 but eval() previously called
        # self.loss(predict, real), silently falling back to the loss
        # function's default mask — made consistent with training.
        loss = self.loss(predict, real, 0.0)
        mape = metrics.masked_mape(predict, real).item()
        rmse = metrics.masked_rmse(predict, real).item()
        return loss.item(), mape, rmse

    def ev_valid(self, name):
        """Evaluate over a whole split; return aggregate (mae, mape, rmse)."""
        self.model.eval()
        outputs = []
        realy = []
        dummy = torch.zeros(10).requires_grad_()
        for iter, (x, y) in enumerate(self.dataloader[name + '_loader'].get_iterator()):
            testx = torch.Tensor(x).to(self.device)
            testx = testx.transpose(1, 3)
            testy = torch.Tensor(y).to(self.device)
            testy = testy.transpose(1, 3)[:, :, :, :self.seq_out_len]
            # NOTE(review): bare .squeeze() drops the batch dim if the last
            # batch has size 1 (ev_test uses squeeze(dim=1)) — confirm batch
            # sizes before relying on this.
            realy.append(testy[:, self.params['out_level'], :, :].squeeze())

            with torch.no_grad():
                preds = self.model(testx, dummy)
                preds = preds.transpose(1, 3)
            outputs.append(preds.squeeze())

        yhat = torch.cat(outputs, dim=0)
        realy = torch.cat(realy, dim=0)

        pred = self.scaler.inverse_transform(yhat, self.params['out_level'])
        realy = self.scaler.inverse_transform(realy, self.params['out_level'])
        mae, mape, rmse = metrics.metric(pred, realy)
        return mae, mape, rmse

    def ev_test(self, name):
        """Evaluate per prediction horizon; return lists (mae, mape, rmse)."""
        self.model.eval()
        outputs = []
        realy = []
        dummy = torch.zeros(10).requires_grad_()

        for iter, (x, y) in enumerate(self.dataloader[name + '_loader'].get_iterator()):
            testx = torch.Tensor(x).to(self.device)
            testx = testx.transpose(1, 3)
            testy = torch.Tensor(y).to(self.device)
            testy = testy.transpose(1, 3)[:, :, :, :self.seq_out_len]
            realy.append(testy[:, self.params['out_level'], :, :].squeeze(dim=1))

            with torch.no_grad():
                preds = self.model(testx, dummy)
                preds = preds.transpose(1, 3)
            outputs.append(preds.squeeze(dim=1))

        yhat = torch.cat(outputs, dim=0)
        realy = torch.cat(realy, dim=0)

        mae = []
        mape = []
        rmse = []
        for i in range(self.seq_out_len):
            pred = self.scaler.inverse_transform(yhat[:, :, i], self.params['out_level'])
            real = realy[:, :, i]
            real = self.scaler.inverse_transform(real, self.params['out_level'])
            results = metrics.metric(pred, real)
            log = 'Evaluate best model on ' + name + ' data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
            print(log.format(i + 1, results[0], results[1], results[2]))
            mae.append(results[0])
            mape.append(results[1])
            rmse.append(results[2])
        return mae, mape, rmse
-------------------------------------------------------------------------------- /trainer/rtrainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from utils import metrics 5 | import torch.optim as optim 6 | 7 | 8 | class RTrainer: 9 | # a trainer for models that make predictions sequentially. 10 | def __init__(self, model, optimizer, lr_scheduler, loss, dataloader, params, net_params, scaler, device): 11 | self.model = model.to(device) 12 | self.optimizer = optimizer 13 | self.lr_scheduler = lr_scheduler 14 | self.scaler = scaler 15 | self.loss = loss 16 | 17 | self.seq_in_len = net_params['seq_in_len'] 18 | self.seq_out_len = net_params['seq_out_len'] 19 | self.num_nodes = net_params['num_nodes'] 20 | self.in_dim = net_params['in_dim'] 21 | self.out_dim = net_params['out_dim'] 22 | 23 | self.clip = params['clip'] 24 | self.print_every = params['print_every'] 25 | self.dataloader = dataloader 26 | self.device = device 27 | self.batches_seen = 0 28 | self.params = params 29 | 30 | def train_epoch(self): 31 | train_loss = [] 32 | train_mape = [] 33 | train_rmse = [] 34 | t1 = time.time() 35 | self.dataloader['train_loader'].shuffle() 36 | train_iterator = self.dataloader['train_loader'].get_iterator() 37 | for iter, (x, y) in enumerate(train_iterator): 38 | x, y = self._prepare_data(x, y) 39 | metrics = self.train(x,y,self.batches_seen) 40 | train_loss.append(metrics[0]) 41 | train_mape.append(metrics[1]) 42 | train_rmse.append(metrics[2]) 43 | if iter % self.print_every == 0: 44 | log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}' 45 | print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]), flush=True) 46 | self.batches_seen += 1 47 | self.lr_scheduler.step() 48 | 49 | t2 = time.time() 50 | return np.mean(train_loss),np.mean(train_mape),np.mean(train_rmse), t2-t1 51 | 52 | def val_epoch(self): 53 | valid_loss = [] 54 | 
valid_mape = [] 55 | valid_rmse = [] 56 | 57 | t1 = time.time() 58 | val_iterator = self.dataloader['val_loader'].get_iterator() 59 | 60 | for _, (x, y) in enumerate(val_iterator): 61 | x, y = self._prepare_data(x, y) 62 | with torch.no_grad(): 63 | metrics = self.eval(x, y) 64 | 65 | valid_loss.append(metrics[0]) 66 | valid_mape.append(metrics[1]) 67 | valid_rmse.append(metrics[2]) 68 | 69 | t2 = time.time() 70 | return np.mean(valid_loss), np.mean(valid_mape), np.mean(valid_rmse), t2 - t1 71 | 72 | def train(self, x, y, batches_seen): 73 | self.model.train() 74 | self.optimizer.zero_grad() 75 | output = self.model(x, y, batches_seen) 76 | if batches_seen == 0: 77 | self.optimizer = optim.Adam(self.model.parameters(), lr=0.01, eps=1e-3) 78 | self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[20, 30, 40, 50], gamma=0.1) 79 | y = y[...,self.params['out_level']].squeeze() 80 | y_true = self.scaler.inverse_transform(y, self.params['out_level']) 81 | y_predicted = self.scaler.inverse_transform(output,self.params['out_level']) 82 | loss = self.loss(y_predicted, y_true) 83 | loss.backward() 84 | 85 | if self.clip is not None: 86 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) 87 | self.optimizer.step() 88 | 89 | mape = metrics.masked_mape(y_predicted,y_true).item() 90 | rmse = metrics.masked_rmse(y_predicted,y_true).item() 91 | return loss.item(),mape,rmse 92 | 93 | def eval(self, x, y): 94 | self.model.eval() 95 | output = self.model(x,y) 96 | y = y[...,self.params['out_level']].squeeze() 97 | y_true = self.scaler.inverse_transform(y,self.params['out_level']) 98 | y_predicted = self.scaler.inverse_transform(output,self.params['out_level']) 99 | loss = self.loss(y_predicted, y_true) 100 | mape = metrics.masked_mape(y_predicted,y_true).item() 101 | rmse = metrics.masked_rmse(y_predicted,y_true).item() 102 | return loss.item(),mape,rmse 103 | 104 | 105 | def ev_valid(self,name): 106 | self.model.eval() 107 | y_preds = 
[] 108 | y_truths = [] 109 | for iter, (x, y) in enumerate(self.dataloader[name+'_loader'].get_iterator()): 110 | x, y = self._prepare_data(x, y) 111 | 112 | with torch.no_grad(): 113 | preds = self.model(x,y) 114 | y_preds.append(preds) 115 | y = y[..., self.params['out_level']].squeeze() 116 | y_truths.append(y) 117 | 118 | y_preds = torch.cat(y_preds, axis=1) 119 | y_truths = torch.cat(y_truths, axis=1) 120 | y_preds = self.scaler.inverse_transform(y_preds,self.params['out_level']) 121 | y_truths = self.scaler.inverse_transform(y_truths,self.params['out_level']) 122 | mae, mape, rmse = metrics.metric(y_preds, y_truths) 123 | return mae, mape, rmse 124 | 125 | def ev_test(self, name): 126 | self.model.eval() 127 | y_preds = [] 128 | y_truths = [] 129 | for iter, (x, y) in enumerate(self.dataloader[name+'_loader'].get_iterator()): 130 | x, y = self._prepare_data(x, y) 131 | 132 | with torch.no_grad(): 133 | preds = self.model(x,y) 134 | y_preds.append(preds) 135 | y = y[..., self.params['out_level']].squeeze() 136 | y_truths.append(y) 137 | 138 | y_preds = torch.cat(y_preds, axis=1) 139 | y_truths = torch.cat(y_truths, axis=1) 140 | 141 | mae = [] 142 | mape = [] 143 | rmse = [] 144 | for i in range(self.seq_out_len): 145 | pred = self.scaler.inverse_transform(y_preds[i,...],self.params['out_level']) 146 | real = self.scaler.inverse_transform(y_truths[i,...],self.params['out_level']) 147 | results = metrics.metric(pred, real) 148 | log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}' 149 | print(log.format(i + 1, results[0], results[1], results[2])) 150 | mae.append(results[0]) 151 | mape.append(results[1]) 152 | rmse.append(results[2]) 153 | return mae, mape, rmse 154 | 155 | 156 | def _prepare_data(self, x, y): 157 | x, y = self._get_x_y(x, y) 158 | x, y = self._get_x_y_in_correct_dims(x, y) 159 | return x.to(self.device), y.to(self.device) 160 | 161 | def _get_x_y(self, x, y): 162 | """ 163 | :param 
x: shape (batch_size, seq_len, num_sensor, input_dim) 164 | :param y: shape (batch_size, horizon, num_sensor, input_dim) 165 | :returns x shape (seq_len, batch_size, num_sensor, input_dim) 166 | y shape (horizon, batch_size, num_sensor, input_dim) 167 | """ 168 | x = torch.from_numpy(x).float() 169 | y = torch.from_numpy(y).float() 170 | x = x.permute(1, 0, 2, 3) 171 | y = y.permute(1, 0, 2, 3) 172 | return x, y 173 | 174 | def _get_x_y_in_correct_dims(self, x, y): 175 | """ 176 | :param x: shape (seq_len, batch_size, num_sensor, input_dim) 177 | :param y: shape (horizon, batch_size, num_sensor, input_dim) 178 | :return: x: shape (seq_len, batch_size, num_sensor * input_dim) 179 | y: shape (horizon, batch_size, num_sensor * output_dim) 180 | """ 181 | batch_size = x.size(1) 182 | x = x[-self.seq_in_len:,:,:,:self.in_dim] 183 | x = x.view(self.seq_in_len, batch_size, self.num_nodes * self.in_dim) 184 | 185 | y = y[-self.seq_out_len:,:,:,:self.in_dim] 186 | y = y.view(self.seq_out_len, batch_size, self.num_nodes, self.in_dim) 187 | 188 | # y = y[..., :1].view(-1, batch_size, 189 | # self.num_nodes) 190 | # return x, y[:self.seq_out_len,...] 
191 | return x, y 192 | -------------------------------------------------------------------------------- /trainer/tg_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import metrics 3 | import numpy as np 4 | import time 5 | class TGtrainer: 6 | #to do: to be deleted from project 7 | def __init__(self, model, optimizer, lr_scheduler, loss, dataloader, params, seq_out_len, scaler, device): 8 | self.model = model 9 | self.model.to(device) 10 | self.dataloader = dataloader 11 | self.scaler = scaler 12 | self.device = device 13 | 14 | self.optimizer = optimizer 15 | self.lr_scheduler = lr_scheduler 16 | self.loss = loss 17 | self.clip = params['clip'] 18 | self.print_every = params['print_every'] 19 | self.seq_out_len = seq_out_len 20 | self.batches_seen = 0 21 | 22 | def train_epoch(self): 23 | train_loss = [] 24 | train_mape = [] 25 | train_rmse = [] 26 | t1 = time.time() 27 | 28 | self.dataloader['train_loader'].shuffle() 29 | for iter, (x, y) in enumerate(self.dataloader['train_loader'].get_iterator()): 30 | 31 | trainx = torch.Tensor(x).to(self.device) 32 | trainx = trainx.transpose(1, 3) 33 | trainy = torch.Tensor(y).to(self.device) 34 | trainy = trainy.transpose(1, 3)[:,:,:,:self.seq_out_len] 35 | metrics = self.train(trainx, trainy, self.batches_seen) 36 | 37 | train_loss.append(metrics[0]) 38 | train_mape.append(metrics[1]) 39 | train_rmse.append(metrics[2]) 40 | if iter % self.print_every == 0: 41 | log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}' 42 | print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]), flush=True) 43 | self.batches_seen +=1 44 | self.lr_scheduler.step() 45 | 46 | t2 = time.time() 47 | return np.mean(train_loss),np.mean(train_mape),np.mean(train_rmse), t2-t1 48 | 49 | def val_epoch(self): 50 | valid_loss = [] 51 | valid_mape = [] 52 | valid_rmse = [] 53 | 54 | t1 = time.time() 55 | for iter, (x, y) in 
enumerate(self.dataloader['val_loader'].get_iterator()): 56 | testx = torch.Tensor(x).to(self.device) 57 | testx = testx.transpose(1, 3) 58 | testy = torch.Tensor(y).to(self.device) 59 | testy = testy.transpose(1, 3)[:,:,:,:self.seq_out_len] 60 | 61 | with torch.no_grad(): 62 | metrics = self.eval(testx, testy) 63 | 64 | valid_loss.append(metrics[0]) 65 | valid_mape.append(metrics[1]) 66 | valid_rmse.append(metrics[2]) 67 | t2 = time.time() 68 | 69 | return np.mean(valid_loss),np.mean(valid_mape),np.mean(valid_rmse), t2-t1 70 | 71 | 72 | def train(self, src, tgt, batches_seen): 73 | self.model.train() 74 | self.optimizer.zero_grad() 75 | src_x = {'v': src} 76 | tgt_x = {'v': tgt} 77 | dummy = torch.zeros(10).requires_grad_() 78 | 79 | predict = self.model(src_x, tgt_x, dummy, batches_seen) 80 | real = torch.unsqueeze(tgt[:,0,:,:],dim=1) 81 | real = self.scaler.inverse_transform(real) 82 | predict = self.scaler.inverse_transform(predict) 83 | 84 | loss = self.loss(predict, real, 0.0) 85 | loss.backward() 86 | if self.clip is not None: 87 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) 88 | 89 | self.optimizer.step() 90 | # mae = util.masked_mae(predict,real,0.0).item() 91 | mape = metrics.masked_mape(predict,real).item() 92 | rmse = metrics.masked_rmse(predict,real).item() 93 | return loss.item(),mape,rmse 94 | 95 | 96 | def eval(self, src, tgt): 97 | self.model.eval() 98 | src_x = {'v': src} 99 | tgt_x = {'v': tgt} 100 | dummy = torch.zeros(10).requires_grad_() 101 | predict = self.model(src_x, tgt_x, dummy) 102 | real = torch.unsqueeze(tgt[:,0,:,:],dim=1) 103 | real = self.scaler.inverse_transform(real) 104 | predict = self.scaler.inverse_transform(predict) 105 | mae, mape, rmse = metrics.metric(predict, real) 106 | return mae, mape, rmse 107 | 108 | 109 | def ev_valid(self, name): 110 | self.model.eval() 111 | outputs = [] 112 | realy = [] 113 | dummy = torch.zeros(10).requires_grad_() 114 | 115 | for iter, (x, y) in 
enumerate(self.dataloader[name+'_loader'].get_iterator()): 116 | testx = torch.Tensor(x).to(self.device) 117 | testx = testx.transpose(1, 3) 118 | 119 | testy = torch.Tensor(y).to(self.device) 120 | testy = testy.transpose(1, 3)[:,:,:,:self.seq_out_len] 121 | realy.append(testy[:,0,:,:].squeeze()) 122 | 123 | src_x = {'v': testx} 124 | testy = torch.Tensor(y).to(self.device) 125 | testy = testy.transpose(1, 3)[:,:,:,:self.seq_out_len] 126 | tgt_x = {'v': testy} 127 | with torch.no_grad(): 128 | preds = self.model(src_x, tgt_x, dummy) 129 | outputs.append(preds.squeeze()) 130 | 131 | pred = torch.cat(outputs, dim=0) 132 | realy = torch.cat(realy, dim=0) 133 | 134 | pred = self.scaler.inverse_transform(pred) 135 | realy = self.scaler.inverse_transform(realy) 136 | mae, mape, rmse = metrics.metric(pred, realy) 137 | return mae, mape, rmse 138 | 139 | def ev_test(self,name): 140 | self.model.eval() 141 | outputs = [] 142 | realy = [] 143 | # realy = torch.Tensor(self.dataloader['y_'+name]).to(self.device) 144 | # realy = realy.transpose(1, 3)[:, 0, :, :self.seq_out_len] 145 | dummy = torch.zeros(10).requires_grad_() 146 | 147 | for iter, (x, y) in enumerate(self.dataloader[name+'_loader'].get_iterator()): 148 | testx = torch.Tensor(x).to(self.device) 149 | testx = testx.transpose(1, 3) 150 | src_x = {'v': testx} 151 | 152 | testy = torch.Tensor(y).to(self.device) 153 | testy = testy.transpose(1, 3)[:,:,:,:self.seq_out_len] 154 | realy.append(testy[:,0,:,:].squeeze()) 155 | 156 | tgt_x = {'v': testy} 157 | with torch.no_grad(): 158 | preds = self.model(src_x, tgt_x, dummy) 159 | 160 | outputs.append(preds.squeeze()) 161 | 162 | yhat = torch.cat(outputs, dim=0) 163 | realy = torch.cat(realy, dim=0) 164 | 165 | mae = [] 166 | mape = [] 167 | rmse = [] 168 | for i in range(self.seq_out_len): 169 | pred = self.scaler.inverse_transform(yhat[:, :, i]) 170 | real = realy[:, :, i] 171 | real = self.scaler.inverse_transform(real) 172 | results = metrics.metric(pred, real) 173 | 
import torch
import numpy as np


def masked_huber(preds, labels, null_val=0, beta=1):
    """Masked Huber (smooth-L1) loss.

    Entries whose label equals ``null_val`` are treated as missing: they get
    zero weight and the remaining weights are renormalized so the mean stays
    unbiased. NaNs produced by the renormalization are zeroed out.

    NOTE(review): the quadratic/linear switch line was corrupted in the
    original source; it is reconstructed here from the standard smooth-L1
    definition (quadratic when |err| < beta, linear otherwise).
    """
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels != null_val)
    assert list(preds.shape) == list(labels.shape), "shapes of two inputs are not equal"
    mask = mask.float()
    mask /= torch.mean((mask))
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    se = (preds - labels) ** 2
    se = se * mask
    se = torch.where(torch.isnan(se), torch.zeros_like(se), se)
    ae = torch.abs(preds - labels)
    ae = ae * mask
    ae = torch.where(torch.isnan(ae), torch.zeros_like(ae), ae)
    # quadratic below beta, linear above -- the smooth-L1 switch
    loss = torch.where(ae < beta, 0.5 * se / beta, ae - 0.5 * beta)
    return torch.mean(loss)


def masked_mse(preds, labels, null_val=0.0001):
    """MSE over entries with labels > null_val (near-zero labels count as missing).

    NOTE(review): the function header was reconstructed to match the visible
    masked_mae/masked_mape pattern; the original text was corrupted here.
    """
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels > null_val)
    assert list(preds.shape) == list(labels.shape), "shapes of two inputs are not equal"
    mask = mask.float()
    mask /= torch.mean((mask))
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = (preds - labels) ** 2
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def masked_rmse(preds, labels, null_val=0.0001):
    """Root of masked_mse with the same masking rule."""
    return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val))


def masked_mae(preds, labels, null_val=0.0001):
    """MAE over entries with labels > null_val (near-zero labels count as missing)."""
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels > null_val)
    assert list(preds.shape) == list(labels.shape), "shapes of two inputs are not equal"
    mask = mask.float()
    mask /= torch.mean((mask))
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(preds - labels)
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def masked_mape(preds, labels, null_val=0.0001):
    """Mean absolute percentage error over entries with labels > null_val.

    The > null_val mask also protects the division by |labels| from
    blowing up on (near-)zero targets.
    """
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels > null_val)
    assert list(preds.shape) == list(labels.shape), "shapes of two inputs are not equal"
    mask = mask.float()
    mask /= torch.mean(mask)
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(preds - labels) / torch.abs(labels)
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def metric(pred, real):
    """Convenience bundle: (masked MAE, MAPE, RMSE) as Python floats."""
    mae = masked_mae(pred, real).item()
    mape = masked_mape(pred, real).item()
    rmse = masked_rmse(pred, real).item()
    return mae, mape, rmse


def masked_mae_loss(y_pred, y_true):
    """MAE loss masking exact-zero targets; kept differentiable for training."""
    mask = (y_true != 0).float()
    mask /= mask.mean()
    loss = torch.abs(y_pred - y_true)
    loss = loss * mask
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    loss[loss != loss] = 0
    return loss.mean()
def _progress(n):
    # tqdm progress range when the package is available, plain range otherwise
    try:
        import tqdm
        return tqdm.trange(n)
    except ImportError:
        return range(n)


def _group_by_relation(dfn, num_nodes, with_dist=True, key_prefix=''):
    """Group flat edge rows by relation and fold (node, t) pairs into flat
    space-time indices (t * num_nodes + node); returns DGL-style graph_data
    (plus per-relation distance tensors when with_dist is True)."""
    g = dict()
    ds = dict()
    for i in _progress(len(dfn)):
        r = dfn.loc[i, 'rel']
        if r not in g.keys():
            g[r] = ([], [])
            if with_dist:
                ds[r] = []
        g[r][0].append(int(dfn.loc[i, 'src_t']) * num_nodes + int(dfn.loc[i, 'src']))
        g[r][1].append(int(dfn.loc[i, 'tgt_t']) * num_nodes + int(dfn.loc[i, 'tgt']))
        if with_dist:
            ds[r].append(dfn.loc[i, 'dis'])

    graph_data = dict()
    print('number of relations ', len(g.keys()))
    for k in g.keys():
        key = ('v', key_prefix + str(k), 'v')
        graph_data[key] = (torch.tensor(np.asarray(g[k][0], dtype=int)),
                           torch.tensor(np.asarray(g[k][1], dtype=int)))
        if with_dist:
            ds[k] = torch.tensor(ds[k])
    if with_dist:
        return graph_data, ds
    return graph_data


def process_t_graph(srclist, tgtlist, dist, T, num_nodes, window=12):
    """Assign the same edge type to connections between a node's current state
    and its neighbor's windowed historical states.

    Rows are accumulated in a plain list and converted once: the original
    per-row DataFrame.append was removed in pandas 2.0.
    """
    df = pd.DataFrame({'src': srclist, 'tgt': tgtlist, 'dist': dist})
    df = df.sort_values(by=['tgt', 'dist'], ignore_index=True)
    print(df.loc[0:30])
    rows = []
    rel = dict()
    for i in _progress(len(df)):
        src_node = int(df.loc[i, 'src'])
        tgt_node = int(df.loc[i, 'tgt'])
        dis = df.loc[i, 'dist']
        # relation counter restarts on each new target node (original semantics)
        if tgt_node not in rel.keys():
            ctr = 0
            rel[tgt_node] = [str(ctr)]
        else:
            ctr += 1
        if dis == 0:
            rela = str(ctr) + "_-1"
            ew = 1
        else:
            rela = str(ctr) + "_1"
            ew = dis

        for j in range(T):
            s = max(j - window, 0)
            for k in range(s, j + 1):
                rows.append({'rel': rela, 'src': src_node, 'src_t': int(k),
                             'tgt': tgt_node, 'tgt_t': int(j), 'dis': ew})
            # self-state edge at time j, skipped for the t=0 self-loop
            if j == 0 and src_node == tgt_node:
                continue
            rows.append({'rel': rela, 'src': tgt_node, 'src_t': int(j),
                         'tgt': tgt_node, 'tgt_t': int(j), 'dis': ew})
    dfn = pd.DataFrame(rows, columns=['rel', 'src', 'src_t', 'tgt', 'tgt_t', 'dis'])
    print(dfn)
    return _group_by_relation(dfn, num_nodes, with_dist=True)


def process_st_graph(srclist, tgtlist, dist, T, num_nodes, window=12):
    """Construct a conventional st graph: each spatial edge replicated at every
    time step under the single relation '1'."""
    df = pd.DataFrame({'src': srclist, 'tgt': tgtlist, 'dist': dist})
    df = df.sort_values(by=['tgt', 'dist'], ignore_index=True)
    print(df.loc[0:30])
    rows = []
    for i in _progress(len(df)):
        src = int(df.loc[i, 'src'])
        tgt = int(df.loc[i, 'tgt'])
        dis = df.loc[i, 'dist']
        for t in range(T):
            rows.append({'rel': '1', 'src': src, 'src_t': t, 'tgt': tgt, 'tgt_t': t, 'dis': dis})
    dfn = pd.DataFrame(rows, columns=['rel', 'src', 'src_t', 'tgt', 'tgt_t', 'dis'])
    print(dfn)
    return _group_by_relation(dfn, num_nodes, with_dist=True)


def process_f_graph(srclist, tgtlist, dist, T, num_nodes, window=12):
    """Fully-connected-in-time variant: every (src time, tgt time) pair gets an
    edge. to do: a function to be deleted."""
    df = pd.DataFrame({'src': srclist, 'tgt': tgtlist, 'dist': dist})
    df = df.sort_values(by=['tgt', 'dist'], ignore_index=True)
    print(df.loc[0:30])
    rows = []
    rel = dict()
    for i in _progress(len(df)):
        src_node = int(df.loc[i, 'src'])
        tgt_node = int(df.loc[i, 'tgt'])
        dis = df.loc[i, 'dist']
        if tgt_node not in rel.keys():
            ctr = 0
            rel[tgt_node] = [str(ctr)]
        else:
            ctr += 1
        if dis == 0:
            rela = str(ctr) + "_-1"
            ew = 1
        elif dis > 1e8:
            rela = str(ctr) + "_0"
            ew = 1e-9 - dis
        else:
            rela = str(ctr) + "_1"
            ew = dis

        for j in range(T):
            for k in range(T):
                rows.append({'rel': rela, 'src': src_node, 'src_t': int(k),
                             'tgt': tgt_node, 'tgt_t': int(j), 'dis': ew})
    dfn = pd.DataFrame(rows, columns=['rel', 'src', 'src_t', 'tgt', 'tgt_t', 'dis'])
    print(dfn)
    return _group_by_relation(dfn, num_nodes, with_dist=True)


def process_s_graph(srclist, tgtlist, dist, T, num_nodes, window=12):
    """Assign the same edge type to connections between a node's current state
    and all of its neighbors' windowed states; the relation name encodes the
    time lag and the distance class."""
    df = pd.DataFrame({'src': srclist, 'tgt': tgtlist, 'dist': dist})
    df = df.sort_values(by=['tgt', 'dist'], ignore_index=True)
    print(df.loc[0:30])
    rows = []
    for i in _progress(len(df)):
        dis = df.loc[i, 'dist']
        if dis == 0:
            rela = "_-1"
        elif dis > 1e8:
            rela = "_0"
        else:
            rela = "_1"
        src = int(df.loc[i, 'src'])
        tgt = int(df.loc[i, 'tgt'])
        for j in range(T):
            s = max(j - window, 0)
            for k in range(s, j + 1):
                rows.append({'rel': str(j - k) + rela, 'src': src, 'src_t': int(k),
                             'tgt': tgt, 'tgt_t': int(j)})
    dfn = pd.DataFrame(rows, columns=['rel', 'src', 'src_t', 'tgt', 'tgt_t'])
    print(dfn)
    graph_data = _group_by_relation(dfn, num_nodes, with_dist=False, key_prefix='r')
    return graph_data
def process_a_graph(srclist, tgtlist, T, num_nodes, window=12):
    """Single-relation space-time graph: each spatial edge connects the
    source's windowed past states to the target's current state.

    to do: a function to be deleted. Rows are accumulated in a plain list:
    the original per-row DataFrame.append was removed in pandas 2.0.
    """
    df = pd.DataFrame({'src': srclist, 'tgt': tgtlist})
    df = df.sort_values(by=['tgt', 'src'], ignore_index=True)
    print(df)
    rows = []
    for i in range(len(df)):
        src = df.loc[i, 'src']
        tgt = df.loc[i, 'tgt']
        for j in range(T):
            s = max(j - window, 0)
            for k in range(s, j + 1):
                rows.append({'rel': 0, 'src': src, 'src_t': k, 'tgt': tgt, 'tgt_t': j})
    dfn = pd.DataFrame(rows, columns=['rel', 'src', 'src_t', 'tgt', 'tgt_t'])
    print(dfn)

    g = dict()
    for i in range(len(dfn)):
        r = dfn.loc[i, 'rel']
        if r not in g.keys():
            g[r] = ([], [])
        # fold (node, t) into a flat space-time index: t * num_nodes + node
        g[r][0].append(int(dfn.loc[i, 'src_t']) * num_nodes + int(dfn.loc[i, 'src']))
        g[r][1].append(int(dfn.loc[i, 'tgt_t']) * num_nodes + int(dfn.loc[i, 'tgt']))

    graph_data = dict()
    print('number of relations ', len(g.keys()))
    for k in g.keys():
        key = ('v', 'r' + str(k), 'v')
        graph_data[key] = (torch.tensor(np.asarray(g[k][0], dtype=int)),
                           torch.tensor(np.asarray(g[k][1], dtype=int)))
    return graph_data


def sym_adj(adj):
    """Symmetrically normalize adjacency matrix: D^-1/2 A D^-1/2 (float32 sparse)."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.  # isolated nodes get zero weight
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32)


def asym_adj(adj):
    """Row-normalize adjacency matrix: D^-1 A (float32 sparse)."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1)).flatten()
    d_inv = np.power(rowsum, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat = sp.diags(d_inv)
    return d_mat.dot(adj).astype(np.float32)


def trans_adj(adj):
    """Column-normalize adjacency matrix: A D^-1 (float32 sparse)."""
    adj = sp.coo_matrix(adj)
    colsum = np.array(adj.sum(0)).flatten()
    d_inv = np.power(colsum, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat = sp.diags(d_inv)
    return adj.dot(d_mat).astype(np.float32)


def calculate_normalized_laplacian(adj):
    """
    # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2
    # D = diag(A 1)
    :param adj: adjacency matrix (dense or sparse)
    :return: normalized Laplacian as a sparse matrix
    """
    adj = sp.coo_matrix(adj)
    d = np.array(adj.sum(1))
    d_inv_sqrt = np.power(d, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
    return normalized_laplacian
def calculate_random_walk_matrix(adj_mx):
    """Row-normalized transition matrix D^-1 A, returned as sparse COO."""
    adj_mx = sp.coo_matrix(adj_mx)
    degree = np.array(adj_mx.sum(1))
    inv_degree = np.power(degree, -1).flatten()
    inv_degree[np.isinf(inv_degree)] = 0.  # isolated nodes contribute nothing
    return sp.diags(inv_degree).dot(adj_mx).tocoo()


def calculate_reverse_random_walk_matrix(adj_mx):
    """Random-walk matrix of the transposed adjacency (follow incoming edges)."""
    return calculate_random_walk_matrix(np.transpose(adj_mx))


def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """Rescale the normalized Laplacian to roughly [-1, 1]: 2 L / lambda_max - I."""
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    L = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        # estimate the largest eigenvalue when the caller did not supply one
        lambda_max, _ = linalg.eigsh(L, 1, which='LM')
        lambda_max = lambda_max[0]
    L = sp.csr_matrix(L)
    num_rows, _ = L.shape
    eye = sp.identity(num_rows, format='csr', dtype=L.dtype)
    L = (2 / lambda_max * L) - eye
    return L.astype(np.float32)


def scaled_Laplacian(W):
    '''
    compute \tilde{L} = 2 L / lambda_max - I for a dense adjacency
    Parameters
    ----------
    W: np.ndarray, shape is (N, N), N is the num of vertices
    Returns
    ----------
    scaled_Laplacian: np.ndarray, shape (N, N)
    '''
    assert W.shape[0] == W.shape[1]

    degree_mat = np.diag(np.sum(W, axis=1))
    L = degree_mat - W
    lambda_max = linalg.eigs(L, k=1, which='LR')[0].real
    return (2 * L) / lambda_max - np.identity(W.shape[0])


def cheb_polynomial(L_tilde, K):
    '''
    compute a list of chebyshev polynomials from T_0 to T_{K-1}
    Parameters
    ----------
    L_tilde: scaled Laplacian, np.ndarray, shape (N, N)
    K: the maximum order of chebyshev polynomials
    Returns
    ----------
    cheb_polynomials: list(np.ndarray), length: K, from T_0 to T_{K-1}
    '''
    N = L_tilde.shape[0]
    polys = [np.identity(N), L_tilde.copy()]
    for order in range(2, K):
        # NOTE(review): '*' is elementwise on ndarrays here (matches the
        # upstream ASTGCN code); a strict Chebyshev recurrence would use
        # matrix multiplication -- confirm intent before changing.
        polys.append(2 * L_tilde * polys[order - 1] - polys[order - 2])
    return polys