├── LICENSE ├── README.md ├── data ├── metr-la.h5.zip ├── model │ ├── README.md │ ├── para_bay.yaml │ └── para_la.yaml ├── pems-bay.h5.zip └── sensor_graph │ ├── adj_mx.pkl │ ├── adj_mx_bay.pkl │ ├── distances_la_2012.csv │ ├── graph_sensor_ids.txt │ └── graph_sensor_locations.csv ├── lib ├── AMSGrad.py ├── __init__.py ├── metrics.py ├── metrics_test.py └── utils.py ├── model ├── __init__.py └── pytorch │ ├── __init__.py │ ├── cell.py │ ├── loss.py │ ├── model.py │ └── supervisor.py ├── requirements.txt ├── scripts ├── __init__.py ├── eval_baseline_methods.py ├── gen_adj_mx.py ├── generate_training_data.py └── generate_visualization_data.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Discrete Graph Structure Learning for Forecasting Multiple Time Series 2 | 3 | This is a PyTorch implementation of the paper "[Discrete Graph Structure Learning for Forecasting Multiple Time Series](https://openreview.net/pdf?id=WEHSlH5mOk)", ICLR 2021. 
4 | 5 | ## Installation 6 | 7 | Install the dependencies using the following command: 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | * torch 14 | * scipy>=0.19.0 15 | * numpy>=1.12.1 16 | * pandas>=0.19.2 17 | * pyyaml 18 | * statsmodels 19 | * tensorflow>=1.3.0 20 | * tables 21 | * future 22 | 23 | 24 | ## Data Preparation 25 | 26 | The traffic data files for Los Angeles (METR-LA) and the Bay Area (PEMS-BAY) are located in the `data/` folder. They are provided by [DCRNN](https://github.com/chnsh/DCRNN_PyTorch). 27 | 28 | Run the following commands to generate the train/val/test datasets at `data/{METR-LA,PEMS-BAY}/{train,val,test}.npz`. 29 | ```bash 30 | # Unzip the datasets 31 | unzip data/metr-la.h5.zip -d data/ 32 | unzip data/pems-bay.h5.zip -d data/ 33 | 34 | # Create data directories 35 | mkdir -p data/{METR-LA,PEMS-BAY} 36 | 37 | # METR-LA 38 | python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5 39 | 40 | # PEMS-BAY 41 | python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5 42 | ``` 43 | 44 | ## Train Model 45 | 46 | To train the model, run: 47 | 48 | ```bash 49 | # Use METR-LA dataset 50 | python train.py --config_filename=data/model/para_la.yaml --temperature=0.5 51 | 52 | # Use PEMS-BAY dataset 53 | python train.py --config_filename=data/model/para_bay.yaml --temperature=0.5 54 | ``` 55 | 56 | Hyperparameters can be modified in the `para_la.yaml` and `para_bay.yaml` files. 57 | 58 | ## Design Your Own Model 59 | 60 | You can modify the model directly in the `model/pytorch/model.py` file. 61 | 62 | ## Citation 63 | 64 | If you use this repository (e.g., the code or the datasets) in your research, please cite the following paper: 65 | ``` 66 | @article{shang2021discrete, 67 | title={Discrete Graph Structure Learning for Forecasting Multiple Time Series}, 68 | author={Shang, Chao and Chen, Jie and Bi, Jinbo}, 69 | journal={arXiv preprint arXiv:2101.06861}, 70 | year={2021} 71 | } 72 | ``` 73 | 74 | ## Acknowledgments 75 | 76 | [DCRNN-PyTorch](https://github.com/chnsh/DCRNN_PyTorch), [GCN](https://github.com/tkipf/gcn), [NRI](https://github.com/ethanfetaya/NRI) and [LDS-GNN](https://github.com/lucfra/LDS-GNN). 77 | -------------------------------------------------------------------------------- /data/metr-la.h5.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoshangcs/GTS/8ed45ff1476639f78c382ff09ecca8e60523e7ce/data/metr-la.h5.zip -------------------------------------------------------------------------------- /data/model/README.md: -------------------------------------------------------------------------------- 1 | The YAML files contain the parameters used by the model. Model performance is sensitive to the `base_lr` and `steps` hyperparameters; please fine-tune them.
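A minimal sketch of one way to sweep these two values, assuming the YAML layout shown in `para_la.yaml` and that the resulting file is passed to `train.py` via `--config_filename` (the output name `para_la_tuned.yaml` is only an example):

```python
# Hypothetical tuning helper (not part of the repository): load a config,
# override the two sensitive hyperparameters, and write a new YAML file.
import yaml

with open('data/model/para_la.yaml') as f:
    config = yaml.safe_load(f)

# base_lr is the initial learning rate; steps are presumably the epochs at
# which the learning rate is scaled by lr_decay_ratio.
config['train']['base_lr'] = 0.01          # e.g. try values around the default 0.005
config['train']['steps'] = [30, 40, 50]    # e.g. shift the decay milestones

with open('data/model/para_la_tuned.yaml', 'w') as f:
    yaml.safe_dump(config, f)
```

Training would then use `python train.py --config_filename=data/model/para_la_tuned.yaml --temperature=0.5`.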
2 | -------------------------------------------------------------------------------- /data/model/para_bay.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | base_dir: data/model 3 | log_level: INFO 4 | data: 5 | batch_size: 64 6 | dataset_dir: data/PEMS-BAY 7 | test_batch_size: 64 8 | val_batch_size: 64 9 | graph_pkl_filename: data/sensor_graph/adj_mx_bay.pkl 10 | 11 | model: 12 | cl_decay_steps: 2000 13 | filter_type: dual_random_walk 14 | horizon: 12 15 | input_dim: 2 16 | l1_decay: 0 17 | max_diffusion_step: 2 18 | num_nodes: 325 19 | num_rnn_layers: 1 20 | output_dim: 1 21 | rnn_units: 128 22 | seq_len: 12 23 | use_curriculum_learning: true 24 | dim_fc: 583408 25 | 26 | train: 27 | base_lr: 0.001 28 | dropout: 0 29 | epoch: 0 30 | epochs: 200 31 | epsilon: 1.0e-3 32 | global_step: 0 33 | lr_decay_ratio: 0.1 34 | max_grad_norm: 5 35 | max_to_keep: 100 36 | min_learning_rate: 2.0e-06 37 | optimizer: adam 38 | patience: 100 39 | steps: [20, 30, 40] 40 | test_every_n_epochs: 5 41 | knn_k: 30 42 | epoch_use_regularization: 200 43 | num_sample: 10 44 | -------------------------------------------------------------------------------- /data/model/para_la.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | base_dir: data/model 3 | log_level: INFO 4 | data: 5 | batch_size: 64 6 | dataset_dir: data/METR-LA 7 | test_batch_size: 64 8 | val_batch_size: 64 9 | graph_pkl_filename: data/sensor_graph/adj_mx.pkl 10 | 11 | model: 12 | cl_decay_steps: 2000 13 | filter_type: dual_random_walk 14 | horizon: 12 15 | input_dim: 2 16 | l1_decay: 0 17 | max_diffusion_step: 3 18 | num_nodes: 207 19 | num_rnn_layers: 1 20 | output_dim: 1 21 | rnn_units: 64 22 | seq_len: 12 23 | use_curriculum_learning: true 24 | dim_fc: 383552 25 | 26 | train: 27 | base_lr: 0.005 28 | dropout: 0 29 | epoch: 0 30 | epochs: 200 31 | epsilon: 1.0e-3 32 | global_step: 0 33 | lr_decay_ratio: 0.1 34 | max_grad_norm: 5 35 | max_to_keep: 100 36 | min_learning_rate: 2.0e-06 37 | optimizer: adam 38 | patience: 100 39 | steps: [20, 30, 40] 40 | test_every_n_epochs: 5 41 | knn_k: 10 42 | epoch_use_regularization: 200 43 | num_sample: 10 44 | -------------------------------------------------------------------------------- /data/pems-bay.h5.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoshangcs/GTS/8ed45ff1476639f78c382ff09ecca8e60523e7ce/data/pems-bay.h5.zip -------------------------------------------------------------------------------- /data/sensor_graph/graph_sensor_ids.txt: -------------------------------------------------------------------------------- 1 | 
773869,767541,767542,717447,717446,717445,773062,767620,737529,717816,765604,767471,716339,773906,765273,716331,771667,716337,769953,769402,769403,769819,769405,716941,717578,716960,717804,767572,767573,773012,773013,764424,769388,716328,717819,769941,760987,718204,718045,769418,768066,772140,773927,760024,774012,774011,767609,769359,760650,716956,769831,761604,717495,716554,773953,767470,716955,764949,773954,767366,769444,773939,774067,769443,767750,767751,767610,773880,764766,717497,717490,717491,717492,717493,765176,717498,717499,765171,718064,718066,765164,769431,769430,717610,767053,767621,772596,772597,767350,767351,716571,773023,767585,773024,717483,718379,717481,717480,717486,764120,772151,718371,717489,717488,717818,718076,718072,767455,767454,761599,717099,773916,716968,769467,717576,717573,717572,717571,717570,764760,718089,769847,717608,767523,716942,718090,769867,717472,717473,759591,764781,765099,762329,716953,716951,767509,765182,769358,772513,716958,718496,769346,773904,718499,764853,761003,717502,759602,717504,763995,717508,765265,773996,773995,717469,717468,764106,717465,764794,717466,717461,717460,717463,717462,769345,716943,772669,717582,717583,717580,716949,717587,772178,717585,716939,768469,764101,767554,773975,773974,717510,717513,717825,767495,767494,717821,717823,717458,717459,769926,764858,717450,717452,717453,759772,717456,771673,772167,769372,774204,769806,717590,717592,717595,772168,718141,769373 -------------------------------------------------------------------------------- /data/sensor_graph/graph_sensor_locations.csv: -------------------------------------------------------------------------------- 1 | index,sensor_id,latitude,longitude 2 | 0,773869,34.15497,-118.31829 3 | 1,767541,34.11621,-118.23799 4 | 2,767542,34.11641,-118.23819 5 | 3,717447,34.07248,-118.26772 6 | 4,717446,34.07142,-118.26572 7 | 5,717445,34.06913,-118.25932 8 | 6,773062,34.05368,-118.23369 9 | 7,767620,34.13486,-118.22932 10 | 8,737529,34.20264,-118.47352 11 | 9,717816,34.15562,-118.46860 12 | 10,765604,34.16415,-118.38223 13 | 11,767471,34.15691,-118.22469 14 | 12,716339,34.07821,-118.28795 15 | 13,773906,34.15660,-118.30266 16 | 14,765273,34.18949,-118.47437 17 | 15,716331,34.07006,-118.26246 18 | 16,771667,34.07314,-118.23388 19 | 17,716337,34.07732,-118.28186 20 | 18,769953,34.20672,-118.19992 21 | 19,769402,34.12095,-118.33911 22 | 20,769403,34.12073,-118.33928 23 | 21,769819,34.20584,-118.19803 24 | 22,769405,34.12634,-118.34482 25 | 23,716941,34.05767,-118.21435 26 | 24,717578,34.15478,-118.27076 27 | 25,716960,34.12121,-118.27164 28 | 26,717804,34.09478,-118.47605 29 | 27,767572,34.12967,-118.22871 30 | 28,767573,34.12964,-118.22901 31 | 29,773012,34.08390,-118.22086 32 | 30,773013,34.08374,-118.22076 33 | 31,764424,34.17878,-118.39469 34 | 32,769388,34.11027,-118.33441 35 | 33,716328,34.06664,-118.25397 36 | 34,717819,34.18784,-118.47407 37 | 35,769941,34.20699,-118.20237 38 | 36,760987,34.15359,-118.34043 39 | 37,718204,34.15541,-118.29575 40 | 38,718045,34.06712,-118.23973 41 | 39,769418,34.12659,-118.34465 42 | 40,768066,34.15118,-118.37480 43 | 41,772140,34.16498,-118.47493 44 | 42,773927,34.15262,-118.28034 45 | 43,760024,34.15930,-118.46483 46 | 44,774012,34.14769,-118.20137 47 | 45,774011,34.14747,-118.20123 48 | 46,767609,34.18555,-118.21733 49 | 47,769359,34.15660,-118.42216 50 | 48,760650,34.07505,-118.23256 51 | 49,716956,34.10658,-118.25544 52 | 50,769831,34.20663,-118.20101 53 | 51,761604,34.15407,-118.28711 54 | 52,717495,34.15664,-118.41326 55 | 
53,716554,34.15597,-118.26660 56 | 54,773953,34.15522,-118.29344 57 | 55,767470,34.15699,-118.22436 58 | 56,716955,34.09579,-118.24427 59 | 57,764949,34.12881,-118.34684 60 | 58,773954,34.15544,-118.29344 61 | 59,767366,34.21216,-118.47341 62 | 60,769444,34.15555,-118.43908 63 | 61,773939,34.15297,-118.37226 64 | 62,774067,34.15362,-118.28441 65 | 63,769443,34.15574,-118.43931 66 | 64,767750,34.09335,-118.20635 67 | 65,767751,34.09335,-118.20616 68 | 66,767610,34.18555,-118.21766 69 | 67,773880,34.15367,-118.34840 70 | 68,764766,34.13338,-118.35350 71 | 69,717497,34.15685,-118.41456 72 | 70,717490,34.14745,-118.37124 73 | 71,717491,34.14761,-118.37110 74 | 72,717492,34.15459,-118.37935 75 | 73,717493,34.15434,-118.39618 76 | 74,765176,34.13286,-118.35135 77 | 75,717498,34.15571,-118.43273 78 | 76,717499,34.15666,-118.44808 79 | 77,765171,34.16789,-118.46896 80 | 78,718064,34.11296,-118.24489 81 | 79,718066,34.12302,-118.22889 82 | 80,765164,34.07898,-118.28911 83 | 81,769431,34.15843,-118.45664 84 | 82,769430,34.15818,-118.45658 85 | 83,717610,34.17886,-118.39497 86 | 84,767053,34.17091,-118.46775 87 | 85,767621,34.13486,-118.22969 88 | 86,772596,34.17126,-118.50495 89 | 87,772597,34.17109,-118.50495 90 | 88,767350,34.18011,-118.47045 91 | 89,767351,34.18022,-118.47022 92 | 90,716571,34.20000,-118.40337 93 | 91,773023,34.05773,-118.24348 94 | 92,767585,34.16556,-118.22432 95 | 93,773024,34.05759,-118.24357 96 | 94,717483,34.11684,-118.33698 97 | 95,718379,34.14224,-118.27812 98 | 96,717481,34.10634,-118.32826 99 | 97,717480,34.10478,-118.32497 100 | 98,717486,34.12974,-118.34809 101 | 99,764120,34.20164,-118.40366 102 | 100,772151,34.16928,-118.49872 103 | 101,718371,34.09017,-118.23849 104 | 102,717489,34.13876,-118.36438 105 | 103,717488,34.13561,-118.36006 106 | 104,717818,34.17220,-118.46753 107 | 105,718076,34.16339,-118.22530 108 | 106,718072,34.14910,-118.22570 109 | 107,767455,34.14347,-118.22704 110 | 108,767454,34.14352,-118.22733 111 | 109,761599,34.14226,-118.27786 112 | 110,717099,34.15648,-118.24674 113 | 111,773916,34.15247,-118.28520 114 | 112,716968,34.16588,-118.29809 115 | 113,769467,34.15451,-118.39699 116 | 114,717576,34.15559,-118.29570 117 | 115,717573,34.15384,-118.32500 118 | 116,717572,34.15351,-118.32751 119 | 117,717571,34.15326,-118.35921 120 | 118,717570,34.15302,-118.35921 121 | 119,764760,34.13401,-118.35506 122 | 120,718089,34.12769,-118.27372 123 | 121,769847,34.20983,-118.22351 124 | 122,717608,34.17154,-118.38812 125 | 123,767523,34.11439,-118.24209 126 | 124,716942,34.05930,-118.21451 127 | 125,718090,34.14847,-118.27969 128 | 126,769867,34.21846,-118.23931 129 | 127,717472,34.10045,-118.31601 130 | 128,717473,34.10054,-118.31581 131 | 129,759591,34.11521,-118.26825 132 | 130,764781,34.16037,-118.47012 133 | 131,765099,34.16378,-118.47224 134 | 132,762329,34.13904,-118.22862 135 | 133,716953,34.09458,-118.24279 136 | 134,716951,34.08581,-118.23182 137 | 135,767509,34.11059,-118.24819 138 | 136,765182,34.06491,-118.25126 139 | 137,769358,34.15679,-118.42222 140 | 138,772513,34.06871,-118.23661 141 | 139,716958,34.11167,-118.26501 142 | 140,718496,34.15403,-118.34232 143 | 141,769346,34.15677,-118.40424 144 | 142,773904,34.15641,-118.30266 145 | 143,718499,34.15469,-118.31253 146 | 144,764853,34.06461,-118.25102 147 | 145,761003,34.15546,-118.30841 148 | 146,717502,34.16521,-118.47484 149 | 147,759602,34.12199,-118.27178 150 | 148,717504,34.16519,-118.49166 151 | 149,763995,34.21979,-118.40931 152 | 150,717508,34.17112,-118.51814 153 | 
151,765265,34.18529,-118.47395 154 | 152,773996,34.14511,-118.21587 155 | 153,773995,34.14483,-118.21587 156 | 154,717469,34.09710,-118.31366 157 | 155,717468,34.09699,-118.31381 158 | 156,764106,34.17169,-118.38801 159 | 157,717465,34.09373,-118.30907 160 | 158,764794,34.16028,-118.46808 161 | 159,717466,34.09359,-118.30918 162 | 160,717461,34.08558,-118.30174 163 | 161,717460,34.08571,-118.30161 164 | 162,717463,34.09004,-118.30590 165 | 163,717462,34.08993,-118.30607 166 | 164,769345,34.15655,-118.40441 167 | 165,716943,34.05987,-118.21492 168 | 166,772669,34.07828,-118.22834 169 | 167,717582,34.15646,-118.26092 170 | 168,717583,34.15627,-118.25506 171 | 169,717580,34.15620,-118.26359 172 | 170,716949,34.08406,-118.22974 173 | 171,717587,34.15402,-118.23893 174 | 172,772178,34.16903,-118.49885 175 | 173,717585,34.15564,-118.24188 176 | 174,716939,34.04301,-118.21724 177 | 175,768469,34.13583,-118.35993 178 | 176,764101,34.16421,-118.38246 179 | 177,767554,34.11966,-118.23143 180 | 178,773975,34.14584,-118.22251 181 | 179,773974,34.14559,-118.22251 182 | 180,717510,34.17128,-118.51976 183 | 181,717513,34.17339,-118.53680 184 | 182,717825,34.22164,-118.47307 185 | 183,767495,34.10377,-118.24992 186 | 184,767494,34.10377,-118.24962 187 | 185,717821,34.20112,-118.47361 188 | 186,717823,34.20264,-118.47326 189 | 187,717458,34.08265,-118.29755 190 | 188,717459,34.08294,-118.29729 191 | 189,769926,34.21356,-118.23113 192 | 190,764858,34.15270,-118.37540 193 | 191,717450,34.07488,-118.27362 194 | 192,717452,34.07502,-118.27356 195 | 193,717453,34.07696,-118.28093 196 | 194,759772,34.17115,-118.30539 197 | 195,717456,34.08102,-118.29325 198 | 196,771673,34.07787,-118.22871 199 | 197,772167,34.16526,-118.47985 200 | 198,769372,34.10270,-118.31730 201 | 199,774204,34.15397,-118.34172 202 | 200,769806,34.19638,-118.18442 203 | 201,717590,34.14929,-118.23182 204 | 202,717592,34.14604,-118.22430 205 | 203,717595,34.14163,-118.18290 206 | 204,772168,34.16542,-118.47985 207 | 205,718141,34.15133,-118.37456 208 | 206,769373,34.10262,-118.31747 -------------------------------------------------------------------------------- /lib/AMSGrad.py: -------------------------------------------------------------------------------- 1 | """AMSGrad for TensorFlow. 
2 | From: https://github.com/taki0112/AMSGrad-Tensorflow 3 | """ 4 | 5 | from tensorflow.python.eager import context 6 | from tensorflow.python.framework import ops 7 | from tensorflow.python.ops import control_flow_ops 8 | from tensorflow.python.ops import math_ops 9 | from tensorflow.python.ops import resource_variable_ops 10 | from tensorflow.python.ops import state_ops 11 | from tensorflow.python.ops import variable_scope 12 | from tensorflow.python.training import optimizer 13 | 14 | 15 | class AMSGrad(optimizer.Optimizer): 16 | def __init__(self, learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8, use_locking=False, name="AMSGrad"): 17 | super(AMSGrad, self).__init__(use_locking, name) 18 | self._lr = learning_rate 19 | self._beta1 = beta1 20 | self._beta2 = beta2 21 | self._epsilon = epsilon 22 | 23 | self._lr_t = None 24 | self._beta1_t = None 25 | self._beta2_t = None 26 | self._epsilon_t = None 27 | 28 | self._beta1_power = None 29 | self._beta2_power = None 30 | 31 | def _create_slots(self, var_list): 32 | first_var = min(var_list, key=lambda x: x.name) 33 | 34 | create_new = self._beta1_power is None 35 | if not create_new and context.in_graph_mode(): 36 | create_new = (self._beta1_power.graph is not first_var.graph) 37 | 38 | if create_new: 39 | with ops.colocate_with(first_var): 40 | self._beta1_power = variable_scope.variable(self._beta1, name="beta1_power", trainable=False) 41 | self._beta2_power = variable_scope.variable(self._beta2, name="beta2_power", trainable=False) 42 | # Create slots for the first and second moments. 43 | for v in var_list: 44 | self._zeros_slot(v, "m", self._name) 45 | self._zeros_slot(v, "v", self._name) 46 | self._zeros_slot(v, "vhat", self._name) 47 | 48 | def _prepare(self): 49 | self._lr_t = ops.convert_to_tensor(self._lr) 50 | self._beta1_t = ops.convert_to_tensor(self._beta1) 51 | self._beta2_t = ops.convert_to_tensor(self._beta2) 52 | self._epsilon_t = ops.convert_to_tensor(self._epsilon) 53 | 54 | def _apply_dense(self, grad, var): 55 | beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) 56 | beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) 57 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 58 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 59 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 60 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 61 | 62 | lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 63 | 64 | # m_t = beta1 * m + (1 - beta1) * g_t 65 | m = self.get_slot(var, "m") 66 | m_scaled_g_values = grad * (1 - beta1_t) 67 | m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking) 68 | 69 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 70 | v = self.get_slot(var, "v") 71 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 72 | v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking) 73 | 74 | # amsgrad 75 | vhat = self.get_slot(var, "vhat") 76 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 77 | v_sqrt = math_ops.sqrt(vhat_t) 78 | 79 | var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) 80 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 81 | 82 | def _resource_apply_dense(self, grad, var): 83 | var = var.handle 84 | beta1_power = math_ops.cast(self._beta1_power, grad.dtype.base_dtype) 85 | beta2_power = math_ops.cast(self._beta2_power, grad.dtype.base_dtype) 86 | lr_t = 
math_ops.cast(self._lr_t, grad.dtype.base_dtype) 87 | beta1_t = math_ops.cast(self._beta1_t, grad.dtype.base_dtype) 88 | beta2_t = math_ops.cast(self._beta2_t, grad.dtype.base_dtype) 89 | epsilon_t = math_ops.cast(self._epsilon_t, grad.dtype.base_dtype) 90 | 91 | lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 92 | 93 | # m_t = beta1 * m + (1 - beta1) * g_t 94 | m = self.get_slot(var, "m").handle 95 | m_scaled_g_values = grad * (1 - beta1_t) 96 | m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking) 97 | 98 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 99 | v = self.get_slot(var, "v").handle 100 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 101 | v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking) 102 | 103 | # amsgrad 104 | vhat = self.get_slot(var, "vhat").handle 105 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 106 | v_sqrt = math_ops.sqrt(vhat_t) 107 | 108 | var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) 109 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 110 | 111 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 112 | beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype) 113 | beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype) 114 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 115 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 116 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 117 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 118 | 119 | lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 120 | 121 | # m_t = beta1 * m + (1 - beta1) * g_t 122 | m = self.get_slot(var, "m") 123 | m_scaled_g_values = grad * (1 - beta1_t) 124 | m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) 125 | with ops.control_dependencies([m_t]): 126 | m_t = scatter_add(m, indices, m_scaled_g_values) 127 | 128 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 129 | v = self.get_slot(var, "v") 130 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 131 | v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) 132 | with ops.control_dependencies([v_t]): 133 | v_t = scatter_add(v, indices, v_scaled_g_values) 134 | 135 | # amsgrad 136 | vhat = self.get_slot(var, "vhat") 137 | vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat)) 138 | v_sqrt = math_ops.sqrt(vhat_t) 139 | var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) 140 | return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t]) 141 | 142 | def _apply_sparse(self, grad, var): 143 | return self._apply_sparse_shared( 144 | grad.values, var, grad.indices, 145 | lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda 146 | x, i, v, use_locking=self._use_locking)) 147 | 148 | def _resource_scatter_add(self, x, i, v): 149 | with ops.control_dependencies( 150 | [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 151 | return x.value() 152 | 153 | def _resource_apply_sparse(self, grad, var, indices): 154 | return self._apply_sparse_shared( 155 | grad, var, indices, self._resource_scatter_add) 156 | 157 | def _finish(self, update_ops, name_scope): 158 | # Update the power accumulators. 
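        # (Added explanatory note) beta1_power and beta2_power accumulate
        # beta1**t and beta2**t across optimizer steps; the _apply_* methods
        # above use them for the bias-corrected step size
        # lr_t * sqrt(1 - beta2_power) / (1 - beta1_power).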
159 | with ops.control_dependencies(update_ops): 160 | with ops.colocate_with(self._beta1_power): 161 | update_beta1 = self._beta1_power.assign( 162 | self._beta1_power * self._beta1_t, 163 | use_locking=self._use_locking) 164 | update_beta2 = self._beta2_power.assign( 165 | self._beta2_power * self._beta2_t, 166 | use_locking=self._use_locking) 167 | return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], 168 | name=name_scope) 169 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoshangcs/GTS/8ed45ff1476639f78c382ff09ecca8e60523e7ce/lib/__init__.py -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def masked_mse_tf(preds, labels, null_val=np.nan): 6 | """ 7 | Accuracy with masking. 8 | :param preds: 9 | :param labels: 10 | :param null_val: 11 | :return: 12 | """ 13 | if np.isnan(null_val): 14 | mask = ~tf.is_nan(labels) 15 | else: 16 | mask = tf.not_equal(labels, null_val) 17 | mask = tf.cast(mask, tf.float32) 18 | mask /= tf.reduce_mean(mask) 19 | mask = tf.where(tf.is_nan(mask), tf.zeros_like(mask), mask) 20 | loss = tf.square(tf.subtract(preds, labels)) 21 | loss = loss * mask 22 | loss = tf.where(tf.is_nan(loss), tf.zeros_like(loss), loss) 23 | return tf.reduce_mean(loss) 24 | 25 | 26 | def masked_mae_tf(preds, labels, null_val=np.nan): 27 | """ 28 | Accuracy with masking. 29 | :param preds: 30 | :param labels: 31 | :param null_val: 32 | :return: 33 | """ 34 | if np.isnan(null_val): 35 | mask = ~tf.is_nan(labels) 36 | else: 37 | mask = tf.not_equal(labels, null_val) 38 | mask = tf.cast(mask, tf.float32) 39 | mask /= tf.reduce_mean(mask) 40 | mask = tf.where(tf.is_nan(mask), tf.zeros_like(mask), mask) 41 | loss = tf.abs(tf.subtract(preds, labels)) 42 | loss = loss * mask 43 | loss = tf.where(tf.is_nan(loss), tf.zeros_like(loss), loss) 44 | return tf.reduce_mean(loss) 45 | 46 | 47 | def masked_rmse_tf(preds, labels, null_val=np.nan): 48 | """ 49 | Accuracy with masking. 
50 | :param preds: 51 | :param labels: 52 | :param null_val: 53 | :return: 54 | """ 55 | return tf.sqrt(masked_mse_tf(preds=preds, labels=labels, null_val=null_val)) 56 | 57 | 58 | def masked_rmse_np(preds, labels, null_val=np.nan): 59 | return np.sqrt(masked_mse_np(preds=preds, labels=labels, null_val=null_val)) 60 | 61 | 62 | def masked_mse_np(preds, labels, null_val=np.nan): 63 | with np.errstate(divide='ignore', invalid='ignore'): 64 | if np.isnan(null_val): 65 | mask = ~np.isnan(labels) 66 | else: 67 | mask = np.not_equal(labels, null_val) 68 | mask = mask.astype('float32') 69 | mask /= np.mean(mask) 70 | rmse = np.square(np.subtract(preds, labels)).astype('float32') 71 | rmse = np.nan_to_num(rmse * mask) 72 | return np.mean(rmse) 73 | 74 | 75 | def masked_mae_np(preds, labels, null_val=np.nan): 76 | with np.errstate(divide='ignore', invalid='ignore'): 77 | if np.isnan(null_val): 78 | mask = ~np.isnan(labels) 79 | else: 80 | mask = np.not_equal(labels, null_val) 81 | mask = mask.astype('float32') 82 | mask /= np.mean(mask) 83 | mae = np.abs(np.subtract(preds, labels)).astype('float32') 84 | mae = np.nan_to_num(mae * mask) 85 | return np.mean(mae) 86 | 87 | 88 | def masked_mape_np(preds, labels, null_val=np.nan): 89 | with np.errstate(divide='ignore', invalid='ignore'): 90 | if np.isnan(null_val): 91 | mask = ~np.isnan(labels) 92 | else: 93 | mask = np.not_equal(labels, null_val) 94 | mask = mask.astype('float32') 95 | mask /= np.mean(mask) 96 | mape = np.abs(np.divide(np.subtract(preds, labels).astype('float32'), labels)) 97 | mape = np.nan_to_num(mask * mape) 98 | return np.mean(mape) 99 | 100 | 101 | # Builds loss function. 102 | def masked_mse_loss(scaler, null_val): 103 | def loss(preds, labels): 104 | if scaler: 105 | preds = scaler.inverse_transform(preds) 106 | labels = scaler.inverse_transform(labels) 107 | return masked_mse_tf(preds=preds, labels=labels, null_val=null_val) 108 | 109 | return loss 110 | 111 | 112 | def masked_rmse_loss(scaler, null_val): 113 | def loss(preds, labels): 114 | if scaler: 115 | preds = scaler.inverse_transform(preds) 116 | labels = scaler.inverse_transform(labels) 117 | return masked_rmse_tf(preds=preds, labels=labels, null_val=null_val) 118 | 119 | return loss 120 | 121 | 122 | def masked_mae_loss(scaler, null_val): 123 | def loss(preds, labels): 124 | if scaler: 125 | preds = scaler.inverse_transform(preds) 126 | labels = scaler.inverse_transform(labels) 127 | mae = masked_mae_tf(preds=preds, labels=labels, null_val=null_val) 128 | return mae 129 | 130 | return loss 131 | 132 | 133 | def calculate_metrics(df_pred, df_test, null_val): 134 | """ 135 | Calculate the MAE, MAPE, RMSE 136 | :param df_pred: 137 | :param df_test: 138 | :param null_val: 139 | :return: 140 | """ 141 | mape = masked_mape_np(preds=df_pred.as_matrix(), labels=df_test.as_matrix(), null_val=null_val) 142 | mae = masked_mae_np(preds=df_pred.as_matrix(), labels=df_test.as_matrix(), null_val=null_val) 143 | rmse = masked_rmse_np(preds=df_pred.as_matrix(), labels=df_test.as_matrix(), null_val=null_val) 144 | return mae, mape, rmse -------------------------------------------------------------------------------- /lib/metrics_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from lib import metrics 7 | 8 | 9 | class MyTestCase(unittest.TestCase): 10 | def test_masked_mape_np(self): 11 | preds = np.array([ 12 | [1, 2, 2], 13 | [3, 4, 5], 14 | ], dtype=np.float32) 
15 | labels = np.array([ 16 | [1, 2, 2], 17 | [3, 4, 4] 18 | ], dtype=np.float32) 19 | mape = metrics.masked_mape_np(preds=preds, labels=labels) 20 | self.assertAlmostEqual(1 / 24.0, mape, delta=1e-5) 21 | 22 | def test_masked_mape_np2(self): 23 | preds = np.array([ 24 | [1, 2, 2], 25 | [3, 4, 5], 26 | ], dtype=np.float32) 27 | labels = np.array([ 28 | [1, 2, 2], 29 | [3, 4, 4] 30 | ], dtype=np.float32) 31 | mape = metrics.masked_mape_np(preds=preds, labels=labels, null_val=4) 32 | self.assertEqual(0., mape) 33 | 34 | def test_masked_mape_np_all_zero(self): 35 | preds = np.array([ 36 | [1, 2], 37 | [3, 4], 38 | ], dtype=np.float32) 39 | labels = np.array([ 40 | [0, 0], 41 | [0, 0] 42 | ], dtype=np.float32) 43 | mape = metrics.masked_mape_np(preds=preds, labels=labels, null_val=0) 44 | self.assertEqual(0., mape) 45 | 46 | def test_masked_mape_np_all_nan(self): 47 | preds = np.array([ 48 | [1, 2], 49 | [3, 4], 50 | ], dtype=np.float32) 51 | labels = np.array([ 52 | [np.nan, np.nan], 53 | [np.nan, np.nan] 54 | ], dtype=np.float32) 55 | mape = metrics.masked_mape_np(preds=preds, labels=labels) 56 | self.assertEqual(0., mape) 57 | 58 | def test_masked_mape_np_nan(self): 59 | preds = np.array([ 60 | [1, 2], 61 | [3, 4], 62 | ], dtype=np.float32) 63 | labels = np.array([ 64 | [np.nan, np.nan], 65 | [np.nan, 3] 66 | ], dtype=np.float32) 67 | mape = metrics.masked_mape_np(preds=preds, labels=labels) 68 | self.assertAlmostEqual(1 / 3., mape, delta=1e-5) 69 | 70 | def test_masked_rmse_np_vanilla(self): 71 | preds = np.array([ 72 | [1, 2], 73 | [3, 4], 74 | ], dtype=np.float32) 75 | labels = np.array([ 76 | [1, 4], 77 | [3, 4] 78 | ], dtype=np.float32) 79 | mape = metrics.masked_rmse_np(preds=preds, labels=labels, null_val=0) 80 | self.assertEqual(1., mape) 81 | 82 | def test_masked_rmse_np_nan(self): 83 | preds = np.array([ 84 | [1, 2], 85 | [3, 4], 86 | ], dtype=np.float32) 87 | labels = np.array([ 88 | [1, np.nan], 89 | [3, 4] 90 | ], dtype=np.float32) 91 | rmse = metrics.masked_rmse_np(preds=preds, labels=labels) 92 | self.assertEqual(0., rmse) 93 | 94 | def test_masked_rmse_np_all_zero(self): 95 | preds = np.array([ 96 | [1, 2], 97 | [3, 4], 98 | ], dtype=np.float32) 99 | labels = np.array([ 100 | [0, 0], 101 | [0, 0] 102 | ], dtype=np.float32) 103 | mape = metrics.masked_rmse_np(preds=preds, labels=labels, null_val=0) 104 | self.assertEqual(0., mape) 105 | 106 | def test_masked_rmse_np_missing(self): 107 | preds = np.array([ 108 | [1, 2], 109 | [3, 4], 110 | ], dtype=np.float32) 111 | labels = np.array([ 112 | [1, 0], 113 | [3, 4] 114 | ], dtype=np.float32) 115 | mape = metrics.masked_rmse_np(preds=preds, labels=labels, null_val=0) 116 | self.assertEqual(0., mape) 117 | 118 | def test_masked_rmse_np2(self): 119 | preds = np.array([ 120 | [1, 2], 121 | [3, 4], 122 | ], dtype=np.float32) 123 | labels = np.array([ 124 | [1, 0], 125 | [3, 3] 126 | ], dtype=np.float32) 127 | rmse = metrics.masked_rmse_np(preds=preds, labels=labels, null_val=0) 128 | self.assertAlmostEqual(np.sqrt(1 / 3.), rmse, delta=1e-5) 129 | 130 | 131 | class TFRMSETestCase(unittest.TestCase): 132 | def test_masked_mse_null(self): 133 | with tf.Session() as sess: 134 | preds = tf.constant(np.array([ 135 | [1, 2], 136 | [3, 4], 137 | ], dtype=np.float32)) 138 | labels = tf.constant(np.array([ 139 | [1, 0], 140 | [3, 3] 141 | ], dtype=np.float32)) 142 | rmse = metrics.masked_mse_tf(preds=preds, labels=labels, null_val=0) 143 | self.assertAlmostEqual(1 / 3.0, sess.run(rmse), delta=1e-5) 144 | 145 | def test_masked_mse_vanilla(self): 
146 | with tf.Session() as sess: 147 | preds = tf.constant(np.array([ 148 | [1, 2], 149 | [3, 4], 150 | ], dtype=np.float32)) 151 | labels = tf.constant(np.array([ 152 | [1, 0], 153 | [3, 3] 154 | ], dtype=np.float32)) 155 | rmse = metrics.masked_mse_tf(preds=preds, labels=labels) 156 | self.assertAlmostEqual(1.25, sess.run(rmse), delta=1e-5) 157 | 158 | def test_masked_mse_all_zero(self): 159 | with tf.Session() as sess: 160 | preds = tf.constant(np.array([ 161 | [1, 2], 162 | [3, 4], 163 | ], dtype=np.float32)) 164 | labels = tf.constant(np.array([ 165 | [0, 0], 166 | [0, 0] 167 | ], dtype=np.float32)) 168 | rmse = metrics.masked_mse_tf(preds=preds, labels=labels, null_val=0) 169 | self.assertAlmostEqual(0., sess.run(rmse), delta=1e-5) 170 | 171 | def test_masked_mse_nan(self): 172 | with tf.Session() as sess: 173 | preds = tf.constant(np.array([ 174 | [1, 2], 175 | [3, 4], 176 | ], dtype=np.float32)) 177 | labels = tf.constant(np.array([ 178 | [1, 2], 179 | [3, np.nan] 180 | ], dtype=np.float32)) 181 | rmse = metrics.masked_mse_tf(preds=preds, labels=labels) 182 | self.assertAlmostEqual(0., sess.run(rmse), delta=1e-5) 183 | 184 | def test_masked_mse_all_nan(self): 185 | with tf.Session() as sess: 186 | preds = tf.constant(np.array([ 187 | [1, 2], 188 | [3, 4], 189 | ], dtype=np.float32)) 190 | labels = tf.constant(np.array([ 191 | [np.nan, np.nan], 192 | [np.nan, np.nan] 193 | ], dtype=np.float32)) 194 | rmse = metrics.masked_mse_tf(preds=preds, labels=labels, null_val=0) 195 | self.assertAlmostEqual(0., sess.run(rmse), delta=1e-5) 196 | 197 | if __name__ == '__main__': 198 | unittest.main() 199 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import os 4 | import pickle 5 | import scipy.sparse as sp 6 | import sys 7 | import tensorflow as tf 8 | 9 | from scipy.sparse import linalg 10 | 11 | 12 | class DataLoader(object): 13 | def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False): 14 | """ 15 | 16 | :param xs: 17 | :param ys: 18 | :param batch_size: 19 | :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size. 20 | """ 21 | self.batch_size = batch_size 22 | self.current_ind = 0 23 | if pad_with_last_sample: 24 | num_padding = (batch_size - (len(xs) % batch_size)) % batch_size 25 | x_padding = np.repeat(xs[-1:], num_padding, axis=0) 26 | y_padding = np.repeat(ys[-1:], num_padding, axis=0) 27 | xs = np.concatenate([xs, x_padding], axis=0) 28 | ys = np.concatenate([ys, y_padding], axis=0) 29 | self.size = len(xs) 30 | self.num_batch = int(self.size // self.batch_size) 31 | if shuffle: 32 | permutation = np.random.permutation(self.size) 33 | xs, ys = xs[permutation], ys[permutation] 34 | self.xs = xs 35 | self.ys = ys 36 | 37 | def get_iterator(self): 38 | self.current_ind = 0 39 | 40 | def _wrapper(): 41 | while self.current_ind < self.num_batch: 42 | start_ind = self.batch_size * self.current_ind 43 | end_ind = min(self.size, self.batch_size * (self.current_ind + 1)) 44 | x_i = self.xs[start_ind: end_ind, ...] 45 | y_i = self.ys[start_ind: end_ind, ...] 
46 | yield (x_i, y_i) 47 | self.current_ind += 1 48 | 49 | return _wrapper() 50 | 51 | 52 | class StandardScaler: 53 | """ 54 | Standard the input 55 | """ 56 | 57 | def __init__(self, mean, std): 58 | self.mean = mean 59 | self.std = std 60 | 61 | def transform(self, data): 62 | return (data - self.mean) / self.std 63 | 64 | def inverse_transform(self, data): 65 | return (data * self.std) + self.mean 66 | 67 | 68 | def add_simple_summary(writer, names, values, global_step): 69 | """ 70 | Writes summary for a list of scalars. 71 | :param writer: 72 | :param names: 73 | :param values: 74 | :param global_step: 75 | :return: 76 | """ 77 | for name, value in zip(names, values): 78 | summary = tf.Summary() 79 | summary_value = summary.value.add() 80 | summary_value.simple_value = value 81 | summary_value.tag = name 82 | writer.add_summary(summary, global_step) 83 | 84 | 85 | def calculate_normalized_laplacian(adj): 86 | """ 87 | # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 88 | # D = diag(A 1) 89 | :param adj: 90 | :return: 91 | """ 92 | adj = sp.coo_matrix(adj) 93 | d = np.array(adj.sum(1)) 94 | d_inv_sqrt = np.power(d, -0.5).flatten() 95 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 96 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 97 | normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 98 | return normalized_laplacian 99 | 100 | 101 | def calculate_random_walk_matrix(adj_mx): 102 | adj_mx = sp.coo_matrix(adj_mx) 103 | d = np.array(adj_mx.sum(1)) 104 | d_inv = np.power(d, -1).flatten() 105 | d_inv[np.isinf(d_inv)] = 0. 106 | d_mat_inv = sp.diags(d_inv) 107 | random_walk_mx = d_mat_inv.dot(adj_mx).tocoo() 108 | return random_walk_mx 109 | 110 | 111 | def calculate_reverse_random_walk_matrix(adj_mx): 112 | return calculate_random_walk_matrix(np.transpose(adj_mx)) 113 | 114 | 115 | def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True): 116 | if undirected: 117 | adj_mx = np.maximum.reduce([adj_mx, adj_mx.T]) 118 | L = calculate_normalized_laplacian(adj_mx) 119 | if lambda_max is None: 120 | lambda_max, _ = linalg.eigsh(L, 1, which='LM') 121 | lambda_max = lambda_max[0] 122 | L = sp.csr_matrix(L) 123 | M, _ = L.shape 124 | I = sp.identity(M, format='csr', dtype=L.dtype) 125 | L = (2 / lambda_max * L) - I 126 | return L.astype(np.float32) 127 | 128 | 129 | def config_logging(log_dir, log_filename='info.log', level=logging.INFO): 130 | # Add file handler and stdout handler 131 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 132 | # Create the log directory if necessary. 133 | try: 134 | os.makedirs(log_dir) 135 | except OSError: 136 | pass 137 | file_handler = logging.FileHandler(os.path.join(log_dir, log_filename)) 138 | file_handler.setFormatter(formatter) 139 | file_handler.setLevel(level=level) 140 | # Add console handler. 
141 | console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 142 | console_handler = logging.StreamHandler(sys.stdout) 143 | console_handler.setFormatter(console_formatter) 144 | console_handler.setLevel(level=level) 145 | logging.basicConfig(handlers=[file_handler, console_handler], level=level) 146 | 147 | 148 | def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO): 149 | logger = logging.getLogger(name) 150 | logger.setLevel(level) 151 | # Add file handler and stdout handler 152 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 153 | file_handler = logging.FileHandler(os.path.join(log_dir, log_filename)) 154 | file_handler.setFormatter(formatter) 155 | # Add console handler. 156 | console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 157 | console_handler = logging.StreamHandler(sys.stdout) 158 | console_handler.setFormatter(console_formatter) 159 | logger.addHandler(file_handler) 160 | logger.addHandler(console_handler) 161 | # Add google cloud log handler 162 | logger.info('Log directory: %s', log_dir) 163 | return logger 164 | 165 | 166 | def get_total_trainable_parameter_size(): 167 | """ 168 | Calculates the total number of trainable parameters in the current graph. 169 | :return: 170 | """ 171 | total_parameters = 0 172 | for variable in tf.trainable_variables(): 173 | # shape is an array of tf.Dimension 174 | total_parameters += np.product([x.value for x in variable.get_shape()]) 175 | return total_parameters 176 | 177 | 178 | def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs): 179 | data = {} 180 | for category in ['train', 'val', 'test']: 181 | cat_data = np.load(os.path.join(dataset_dir, category + '.npz')) 182 | data['x_' + category] = cat_data['x'] 183 | data['y_' + category] = cat_data['y'] 184 | scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std()) 185 | # Data format 186 | for category in ['train', 'val', 'test']: 187 | data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0]) 188 | data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0]) 189 | data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True) 190 | data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False) 191 | data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False) 192 | data['scaler'] = scaler 193 | 194 | return data 195 | 196 | ''' 197 | def load_dataset_with_time(dataset_dir, batch_size, test_batch_size=None, **kwargs): 198 | data = {} 199 | for category in ['train', 'val', 'test']: 200 | cat_data = np.load(os.path.join(dataset_dir, category + '.npz')) 201 | data['x_' + category] = cat_data['x'] 202 | data['y_' + category] = cat_data['y'] 203 | data['time_' + category] = cat_data['time'] 204 | scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std()) 205 | # Data format 206 | for category in ['train', 'val', 'test']: 207 | data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0]) 208 | data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0]) 209 | data['train_loader'] = DataLoader(data['x_train'], data['y_train'], data['time_train'], batch_size, shuffle=True) 210 | data['val_loader'] = DataLoader(data['x_val'], data['y_val'], data['time_val'], test_batch_size, shuffle=False) 211 | 
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], data['time_test'], test_batch_size, shuffle=False) 212 | data['scaler'] = scaler 213 | 214 | return data 215 | ''' 216 | 217 | def load_graph_data(pkl_filename): 218 | sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename) 219 | return sensor_ids, sensor_id_to_ind, adj_mx 220 | 221 | 222 | def load_pickle(pickle_file): 223 | try: 224 | with open(pickle_file, 'rb') as f: 225 | pickle_data = pickle.load(f) 226 | except UnicodeDecodeError as e: 227 | with open(pickle_file, 'rb') as f: 228 | pickle_data = pickle.load(f, encoding='latin1') 229 | except Exception as e: 230 | print('Unable to load data ', pickle_file, ':', e) 231 | raise 232 | return pickle_data 233 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoshangcs/GTS/8ed45ff1476639f78c382ff09ecca8e60523e7ce/model/__init__.py -------------------------------------------------------------------------------- /model/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoshangcs/GTS/8ed45ff1476639f78c382ff09ecca8e60523e7ce/model/pytorch/__init__.py -------------------------------------------------------------------------------- /model/pytorch/cell.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from lib import utils 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | 6 | class LayerParams: 7 | def __init__(self, rnn_network: torch.nn.Module, layer_type: str): 8 | self._rnn_network = rnn_network 9 | self._params_dict = {} 10 | self._biases_dict = {} 11 | self._type = layer_type 12 | 13 | def get_weights(self, shape): 14 | if shape not in self._params_dict: 15 | nn_param = torch.nn.Parameter(torch.empty(*shape, device=device)) 16 | torch.nn.init.xavier_normal_(nn_param) 17 | self._params_dict[shape] = nn_param 18 | self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)), 19 | nn_param) 20 | return self._params_dict[shape] 21 | 22 | def get_biases(self, length, bias_start=0.0): 23 | if length not in self._biases_dict: 24 | biases = torch.nn.Parameter(torch.empty(length, device=device)) 25 | torch.nn.init.constant_(biases, bias_start) 26 | self._biases_dict[length] = biases 27 | self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)), 28 | biases) 29 | 30 | return self._biases_dict[length] 31 | 32 | 33 | class DCGRUCell(torch.nn.Module): 34 | def __init__(self, num_units, max_diffusion_step, num_nodes, nonlinearity='tanh', 35 | filter_type="laplacian", use_gc_for_ru=True): 36 | """ 37 | 38 | :param num_units: 39 | :param adj_mx: 40 | :param max_diffusion_step: 41 | :param num_nodes: 42 | :param nonlinearity: 43 | :param filter_type: "laplacian", "random_walk", "dual_random_walk". 44 | :param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates. 45 | """ 46 | 47 | super().__init__() 48 | self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu 49 | # support other nonlinearities up here? 
50 | self._num_nodes = num_nodes 51 | self._num_units = num_units 52 | self._max_diffusion_step = max_diffusion_step 53 | self._supports = [] 54 | self._use_gc_for_ru = use_gc_for_ru 55 | 56 | ''' 57 | Option: 58 | if filter_type == "laplacian": 59 | supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None)) 60 | elif filter_type == "random_walk": 61 | supports.append(utils.calculate_random_walk_matrix(adj_mx).T) 62 | elif filter_type == "dual_random_walk": 63 | supports.append(utils.calculate_random_walk_matrix(adj_mx).T) 64 | supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T) 65 | else: 66 | supports.append(utils.calculate_scaled_laplacian(adj_mx)) 67 | for support in supports: 68 | self._supports.append(self._build_sparse_matrix(support)) 69 | ''' 70 | 71 | self._fc_params = LayerParams(self, 'fc') 72 | self._gconv_params = LayerParams(self, 'gconv') 73 | 74 | @staticmethod 75 | def _build_sparse_matrix(L): 76 | L = L.tocoo() 77 | indices = np.column_stack((L.row, L.col)) 78 | # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L) 79 | indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))] 80 | L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device) 81 | return L 82 | 83 | def _calculate_random_walk_matrix(self, adj_mx): 84 | 85 | # tf.Print(adj_mx, [adj_mx], message="This is adj: ") 86 | 87 | adj_mx = adj_mx + torch.eye(int(adj_mx.shape[0])).to(device) 88 | d = torch.sum(adj_mx, 1) 89 | d_inv = 1. / d 90 | d_inv = torch.where(torch.isinf(d_inv), torch.zeros(d_inv.shape).to(device), d_inv) 91 | d_mat_inv = torch.diag(d_inv) 92 | random_walk_mx = torch.mm(d_mat_inv, adj_mx) 93 | return random_walk_mx 94 | 95 | def forward(self, inputs, hx, adj): 96 | """Gated recurrent unit (GRU) with Graph Convolution. 97 | :param inputs: (B, num_nodes * input_dim) 98 | :param hx: (B, num_nodes * rnn_units) 99 | 100 | :return 101 | - Output: A `2-D` tensor with shape `(B, num_nodes * rnn_units)`. 
102 | """ 103 | adj_mx = self._calculate_random_walk_matrix(adj).t() 104 | output_size = 2 * self._num_units 105 | if self._use_gc_for_ru: 106 | fn = self._gconv 107 | else: 108 | fn = self._fc 109 | value = torch.sigmoid(fn(inputs, adj_mx, hx, output_size, bias_start=1.0)) 110 | value = torch.reshape(value, (-1, self._num_nodes, output_size)) 111 | r, u = torch.split(tensor=value, split_size_or_sections=self._num_units, dim=-1) 112 | r = torch.reshape(r, (-1, self._num_nodes * self._num_units)) 113 | u = torch.reshape(u, (-1, self._num_nodes * self._num_units)) 114 | 115 | c = self._gconv(inputs, adj_mx, r * hx, self._num_units) 116 | if self._activation is not None: 117 | c = self._activation(c) 118 | 119 | new_state = u * hx + (1.0 - u) * c 120 | return new_state 121 | 122 | @staticmethod 123 | def _concat(x, x_): 124 | x_ = x_.unsqueeze(0) 125 | return torch.cat([x, x_], dim=0) 126 | 127 | def _fc(self, inputs, state, output_size, bias_start=0.0): 128 | batch_size = inputs.shape[0] 129 | inputs = torch.reshape(inputs, (batch_size * self._num_nodes, -1)) 130 | state = torch.reshape(state, (batch_size * self._num_nodes, -1)) 131 | inputs_and_state = torch.cat([inputs, state], dim=-1) 132 | input_size = inputs_and_state.shape[-1] 133 | weights = self._fc_params.get_weights((input_size, output_size)) 134 | value = torch.sigmoid(torch.matmul(inputs_and_state, weights)) 135 | biases = self._fc_params.get_biases(output_size, bias_start) 136 | value += biases 137 | return value 138 | 139 | def _gconv(self, inputs, adj_mx, state, output_size, bias_start=0.0): 140 | # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim) 141 | batch_size = inputs.shape[0] 142 | inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1)) 143 | state = torch.reshape(state, (batch_size, self._num_nodes, -1)) 144 | inputs_and_state = torch.cat([inputs, state], dim=2) 145 | input_size = inputs_and_state.size(2) 146 | 147 | x = inputs_and_state 148 | x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size) 149 | x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size]) 150 | x = torch.unsqueeze(x0, 0) 151 | 152 | if self._max_diffusion_step == 0: 153 | pass 154 | else: 155 | x1 = torch.mm(adj_mx, x0) 156 | x = self._concat(x, x1) 157 | 158 | for k in range(2, self._max_diffusion_step + 1): 159 | x2 = 2 * torch.mm(adj_mx, x1) - x0 160 | x = self._concat(x, x2) 161 | x1, x0 = x2, x1 162 | ''' 163 | Option: 164 | for support in self._supports: 165 | x1 = torch.sparse.mm(support, x0) 166 | x = self._concat(x, x1) 167 | 168 | for k in range(2, self._max_diffusion_step + 1): 169 | x2 = 2 * torch.sparse.mm(support, x1) - x0 170 | x = self._concat(x, x2) 171 | x1, x0 = x2, x1 172 | ''' 173 | num_matrices = self._max_diffusion_step + 1 # Adds for x itself. 
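        # At this point x stacks the diffusion terms [x0, A x0, 2 A x1 - x0, ...] along dim 0,
        # i.e. num_matrices tensors of shape (num_nodes, input_size * batch_size). The reshape
        # and permute below fold them into one (batch_size * num_nodes, input_size * num_matrices)
        # matrix so that a single dense weight multiplication covers all diffusion orders.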
174 | x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size]) 175 | x = x.permute(3, 1, 2, 0) # (batch_size, num_nodes, input_size, order) 176 | x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices]) 177 | 178 | weights = self._gconv_params.get_weights((input_size * num_matrices, output_size)) 179 | x = torch.matmul(x, weights) # (batch_size * self._num_nodes, output_size) 180 | 181 | biases = self._gconv_params.get_biases(output_size, bias_start) 182 | x += biases 183 | # Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim) 184 | return torch.reshape(x, [batch_size, self._num_nodes * output_size]) 185 | -------------------------------------------------------------------------------- /model/pytorch/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def masked_mae_loss(y_pred, y_true): 5 | mask = (y_true != 0).float() 6 | mask /= mask.mean() 7 | loss = torch.abs(y_pred - y_true) 8 | loss = loss * mask 9 | # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3 10 | loss[loss != loss] = 0 11 | return loss.mean() 12 | 13 | def masked_mape_loss(y_pred, y_true): 14 | mask = (y_true != 0).float() 15 | mask /= mask.mean() 16 | loss = torch.abs(torch.div(y_true - y_pred, y_true)) 17 | loss = loss * mask 18 | # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3 19 | loss[loss != loss] = 0 20 | return loss.mean() 21 | 22 | def masked_rmse_loss(y_pred, y_true): 23 | mask = (y_true != 0).float() 24 | mask /= mask.mean() 25 | loss = torch.pow(y_true - y_pred, 2) 26 | loss = loss * mask 27 | # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3 28 | loss[loss != loss] = 0 29 | return torch.sqrt(loss.mean()) 30 | 31 | def masked_mse_loss(y_pred, y_true): 32 | mask = (y_true != 0).float() 33 | mask /= mask.mean() 34 | loss = torch.pow(y_true - y_pred, 2) 35 | loss = loss * mask 36 | # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3 37 | loss[loss != loss] = 0 38 | return loss.mean() 39 | 40 | -------------------------------------------------------------------------------- /model/pytorch/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from model.pytorch.cell import DCGRUCell 5 | import numpy as np 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | def count_parameters(model): 9 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 10 | 11 | def cosine_similarity_torch(x1, x2=None, eps=1e-8): 12 | x2 = x1 if x2 is None else x2 13 | w1 = x1.norm(p=2, dim=1, keepdim=True) 14 | w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True) 15 | return torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps) 16 | 17 | def sample_gumbel(shape, eps=1e-20): 18 | U = torch.rand(shape).to(device) 19 | return -torch.autograd.Variable(torch.log(-torch.log(U + eps) + eps)) 20 | 21 | def gumbel_softmax_sample(logits, temperature, eps=1e-10): 22 | sample = sample_gumbel(logits.size(), eps=eps) 23 | y = logits + sample 24 | return F.softmax(y / temperature, dim=-1) 25 | 26 | def gumbel_softmax(logits, temperature, hard=False, eps=1e-10): 27 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 
28 |     Args:
29 |       logits: [batch_size, n_class] unnormalized log-probs
30 |       temperature: non-negative scalar
31 |       hard: if True, take argmax, but differentiate w.r.t. soft sample y
32 |     Returns:
33 |       [batch_size, n_class] sample from the Gumbel-Softmax distribution.
34 |       If hard=True, then the returned sample will be one-hot, otherwise it will
35 |       be a probability distribution that sums to 1 across classes.
36 |     """
37 |     y_soft = gumbel_softmax_sample(logits, temperature=temperature, eps=eps)
38 |     if hard:
39 |         shape = logits.size()
40 |         _, k = y_soft.data.max(-1)
41 |         y_hard = torch.zeros(*shape).to(device)
42 |         y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
43 |         y = torch.autograd.Variable(y_hard - y_soft.data) + y_soft
44 |     else:
45 |         y = y_soft
46 |     return y
47 | 
48 | class Seq2SeqAttrs:
49 |     def __init__(self, **model_kwargs):
50 |         #self.adj_mx = adj_mx
51 |         self.max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
52 |         self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
53 |         self.filter_type = model_kwargs.get('filter_type', 'laplacian')
54 |         self.num_nodes = int(model_kwargs.get('num_nodes', 1))
55 |         self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
56 |         self.rnn_units = int(model_kwargs.get('rnn_units'))
57 |         self.hidden_state_size = self.num_nodes * self.rnn_units
58 | 
59 | 
60 | class EncoderModel(nn.Module, Seq2SeqAttrs):
61 |     def __init__(self, **model_kwargs):
62 |         nn.Module.__init__(self)
63 |         Seq2SeqAttrs.__init__(self, **model_kwargs)
64 |         self.input_dim = int(model_kwargs.get('input_dim', 1))
65 |         self.seq_len = int(model_kwargs.get('seq_len'))  # for the encoder
66 |         self.dcgru_layers = nn.ModuleList(
67 |             [DCGRUCell(self.rnn_units, self.max_diffusion_step, self.num_nodes,
68 |                        filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
69 | 
70 |     def forward(self, inputs, adj, hidden_state=None):
71 |         """
72 |         Encoder forward pass.
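        Runs one time step through the stacked DCGRU layers; GTSModel.encoder calls this once
        per input step over seq_len, feeding the returned hidden state back in.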
73 | :param inputs: shape (batch_size, self.num_nodes * self.input_dim) 74 | :param hidden_state: (num_layers, batch_size, self.hidden_state_size) 75 | optional, zeros if not provided 76 | :return: output: # shape (batch_size, self.hidden_state_size) 77 | hidden_state # shape (num_layers, batch_size, self.hidden_state_size) 78 | (lower indices mean lower layers) 79 | """ 80 | batch_size, _ = inputs.size() 81 | if hidden_state is None: 82 | hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size), 83 | device=device) 84 | hidden_states = [] 85 | output = inputs 86 | for layer_num, dcgru_layer in enumerate(self.dcgru_layers): 87 | next_hidden_state = dcgru_layer(output, hidden_state[layer_num], adj) 88 | hidden_states.append(next_hidden_state) 89 | output = next_hidden_state 90 | 91 | return output, torch.stack(hidden_states) # runs in O(num_layers) so not too slow 92 | 93 | 94 | class DecoderModel(nn.Module, Seq2SeqAttrs): 95 | def __init__(self, **model_kwargs): 96 | # super().__init__(is_training, adj_mx, **model_kwargs) 97 | nn.Module.__init__(self) 98 | Seq2SeqAttrs.__init__(self, **model_kwargs) 99 | self.output_dim = int(model_kwargs.get('output_dim', 1)) 100 | self.horizon = int(model_kwargs.get('horizon', 1)) # for the decoder 101 | self.projection_layer = nn.Linear(self.rnn_units, self.output_dim) 102 | self.dcgru_layers = nn.ModuleList( 103 | [DCGRUCell(self.rnn_units, self.max_diffusion_step, self.num_nodes, 104 | filter_type=self.filter_type) for _ in range(self.num_rnn_layers)]) 105 | 106 | def forward(self, inputs, adj, hidden_state=None): 107 | """ 108 | :param inputs: shape (batch_size, self.num_nodes * self.output_dim) 109 | :param hidden_state: (num_layers, batch_size, self.hidden_state_size) 110 | optional, zeros if not provided 111 | :return: output: # shape (batch_size, self.num_nodes * self.output_dim) 112 | hidden_state # shape (num_layers, batch_size, self.hidden_state_size) 113 | (lower indices mean lower layers) 114 | """ 115 | hidden_states = [] 116 | output = inputs 117 | for layer_num, dcgru_layer in enumerate(self.dcgru_layers): 118 | next_hidden_state = dcgru_layer(output, hidden_state[layer_num], adj) 119 | hidden_states.append(next_hidden_state) 120 | output = next_hidden_state 121 | 122 | projected = self.projection_layer(output.view(-1, self.rnn_units)) 123 | output = projected.view(-1, self.num_nodes * self.output_dim) 124 | 125 | return output, torch.stack(hidden_states) 126 | 127 | 128 | class GTSModel(nn.Module, Seq2SeqAttrs): 129 | def __init__(self, temperature, logger, **model_kwargs): 130 | super().__init__() 131 | Seq2SeqAttrs.__init__(self, **model_kwargs) 132 | self.encoder_model = EncoderModel(**model_kwargs) 133 | self.decoder_model = DecoderModel(**model_kwargs) 134 | self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000)) 135 | self.use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False)) 136 | self._logger = logger 137 | self.temperature = temperature 138 | self.dim_fc = int(model_kwargs.get('dim_fc', False)) 139 | self.embedding_dim = 100 140 | self.conv1 = torch.nn.Conv1d(1, 8, 10, stride=1) # .to(device) 141 | self.conv2 = torch.nn.Conv1d(8, 16, 10, stride=1) # .to(device) 142 | self.hidden_drop = torch.nn.Dropout(0.2) 143 | self.fc = torch.nn.Linear(self.dim_fc, self.embedding_dim) 144 | self.bn1 = torch.nn.BatchNorm1d(8) 145 | self.bn2 = torch.nn.BatchNorm1d(16) 146 | self.bn3 = torch.nn.BatchNorm1d(self.embedding_dim) 147 | self.fc_out = nn.Linear(self.embedding_dim * 
2, self.embedding_dim) 148 | self.fc_cat = nn.Linear(self.embedding_dim, 2) 149 | def encode_onehot(labels): 150 | classes = set(labels) 151 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 152 | enumerate(classes)} 153 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 154 | dtype=np.int32) 155 | return labels_onehot 156 | # Generate off-diagonal interaction graph 157 | off_diag = np.ones([self.num_nodes, self.num_nodes]) 158 | rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32) 159 | rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32) 160 | self.rel_rec = torch.FloatTensor(rel_rec).to(device) 161 | self.rel_send = torch.FloatTensor(rel_send).to(device) 162 | 163 | 164 | def _compute_sampling_threshold(self, batches_seen): 165 | return self.cl_decay_steps / ( 166 | self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps)) 167 | 168 | def encoder(self, inputs, adj): 169 | """ 170 | Encoder forward pass 171 | :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 172 | :return: encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size) 173 | """ 174 | encoder_hidden_state = None 175 | for t in range(self.encoder_model.seq_len): 176 | _, encoder_hidden_state = self.encoder_model(inputs[t], adj, encoder_hidden_state) 177 | 178 | return encoder_hidden_state 179 | 180 | def decoder(self, encoder_hidden_state, adj, labels=None, batches_seen=None): 181 | """ 182 | Decoder forward pass 183 | :param encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size) 184 | :param labels: (self.horizon, batch_size, self.num_nodes * self.output_dim) [optional, not exist for inference] 185 | :param batches_seen: global step [optional, not exist for inference] 186 | :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) 187 | """ 188 | batch_size = encoder_hidden_state.size(1) 189 | go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim), 190 | device=device) 191 | decoder_hidden_state = encoder_hidden_state 192 | decoder_input = go_symbol 193 | 194 | outputs = [] 195 | 196 | for t in range(self.decoder_model.horizon): 197 | decoder_output, decoder_hidden_state = self.decoder_model(decoder_input, adj, 198 | decoder_hidden_state) 199 | decoder_input = decoder_output 200 | outputs.append(decoder_output) 201 | if self.training and self.use_curriculum_learning: 202 | c = np.random.uniform(0, 1) 203 | if c < self._compute_sampling_threshold(batches_seen): 204 | decoder_input = labels[t] 205 | outputs = torch.stack(outputs) 206 | return outputs 207 | 208 | def forward(self, label, inputs, node_feas, temp, gumbel_soft, labels=None, batches_seen=None): 209 | """ 210 | :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 211 | :param labels: shape (horizon, batch_size, num_sensor * output) 212 | :param batches_seen: batches seen till now 213 | :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) 214 | """ 215 | x = node_feas.transpose(1, 0).view(self.num_nodes, 1, -1) 216 | x = self.conv1(x) 217 | x = F.relu(x) 218 | x = self.bn1(x) 219 | # x = self.hidden_drop(x) 220 | x = self.conv2(x) 221 | x = F.relu(x) 222 | x = self.bn2(x) 223 | x = x.view(self.num_nodes, -1) 224 | x = self.fc(x) 225 | x = F.relu(x) 226 | x = self.bn3(x) 227 | 228 | receivers = torch.matmul(self.rel_rec, x) 229 | senders = torch.matmul(self.rel_send, x) 230 | x = torch.cat([senders, receivers], dim=1) 231 | x = torch.relu(self.fc_out(x)) 232 
| x = self.fc_cat(x) 233 | 234 | adj = gumbel_softmax(x, temperature=temp, hard=True) 235 | adj = adj[:, 0].clone().reshape(self.num_nodes, -1) 236 | # mask = torch.eye(self.num_nodes, self.num_nodes).to(device).byte() 237 | mask = torch.eye(self.num_nodes, self.num_nodes).bool().to(device) 238 | adj.masked_fill_(mask, 0) 239 | 240 | encoder_hidden_state = self.encoder(inputs, adj) 241 | self._logger.debug("Encoder complete, starting decoder") 242 | outputs = self.decoder(encoder_hidden_state, adj, labels, batches_seen=batches_seen) 243 | self._logger.debug("Decoder complete") 244 | if batches_seen == 0: 245 | self._logger.info( 246 | "Total trainable parameters {}".format(count_parameters(self)) 247 | ) 248 | 249 | return outputs, x.softmax(-1)[:, 0].clone().reshape(self.num_nodes, -1) 250 | -------------------------------------------------------------------------------- /model/pytorch/supervisor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | import numpy as np 4 | from lib import utils 5 | from model.pytorch.model import GTSModel 6 | from model.pytorch.loss import masked_mae_loss, masked_mape_loss, masked_rmse_loss, masked_mse_loss 7 | import pandas as pd 8 | import os 9 | import time 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | class GTSSupervisor: 14 | def __init__(self, save_adj_name, temperature, **kwargs): 15 | self._kwargs = kwargs 16 | self._data_kwargs = kwargs.get('data') 17 | self._model_kwargs = kwargs.get('model') 18 | self._train_kwargs = kwargs.get('train') 19 | self.temperature = float(temperature) 20 | self.opt = self._train_kwargs.get('optimizer') 21 | self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.) 22 | self.ANNEAL_RATE = 0.00003 23 | self.temp_min = 0.1 24 | self.save_adj_name = save_adj_name 25 | self.epoch_use_regularization = self._train_kwargs.get('epoch_use_regularization') 26 | self.num_sample = self._train_kwargs.get('num_sample') 27 | 28 | # logging. 
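        # The run directory is derived from the hyper-parameters and a timestamp (see
        # _get_log_dir below); TensorBoard events go to runs/<log_dir> and the text log
        # is written as info.log under the same directory.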
29 | self._log_dir = self._get_log_dir(kwargs) 30 | self._writer = SummaryWriter('runs/' + self._log_dir) 31 | log_level = self._kwargs.get('log_level', 'INFO') 32 | self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level) 33 | 34 | # data set 35 | self._data = utils.load_dataset(**self._data_kwargs) 36 | self.standard_scaler = self._data['scaler'] 37 | 38 | ### Feas 39 | if self._data_kwargs['dataset_dir'] == 'data/METR-LA': 40 | df = pd.read_hdf('./data/metr-la.h5') 41 | elif self._data_kwargs['dataset_dir'] == 'data/PEMS-BAY': 42 | df = pd.read_hdf('./data/pems-bay.h5') 43 | #else: 44 | # df = pd.read_csv('./data/pmu_normalized.csv', header=None) 45 | # df = df.transpose() 46 | num_samples = df.shape[0] 47 | num_train = round(num_samples * 0.7) 48 | df = df[:num_train].values 49 | scaler = utils.StandardScaler(mean=df.mean(), std=df.std()) 50 | train_feas = scaler.transform(df) 51 | self._train_feas = torch.Tensor(train_feas).to(device) 52 | #print(self._train_feas.shape) 53 | 54 | k = self._train_kwargs.get('knn_k') 55 | knn_metric = 'cosine' 56 | from sklearn.neighbors import kneighbors_graph 57 | g = kneighbors_graph(train_feas.T, k, metric=knn_metric) 58 | g = np.array(g.todense(), dtype=np.float32) 59 | self.adj_mx = torch.Tensor(g).to(device) 60 | self.num_nodes = int(self._model_kwargs.get('num_nodes', 1)) 61 | self.input_dim = int(self._model_kwargs.get('input_dim', 1)) 62 | self.seq_len = int(self._model_kwargs.get('seq_len')) # for the encoder 63 | self.output_dim = int(self._model_kwargs.get('output_dim', 1)) 64 | self.use_curriculum_learning = bool( 65 | self._model_kwargs.get('use_curriculum_learning', False)) 66 | self.horizon = int(self._model_kwargs.get('horizon', 1)) # for the decoder 67 | 68 | # setup model 69 | GTS_model = GTSModel(self.temperature, self._logger, **self._model_kwargs) 70 | self.GTS_model = GTS_model.cuda() if torch.cuda.is_available() else GTS_model 71 | self._logger.info("Model created") 72 | 73 | self._epoch_num = self._train_kwargs.get('epoch', 0) 74 | if self._epoch_num > 0: 75 | self.load_model() 76 | 77 | @staticmethod 78 | def _get_log_dir(kwargs): 79 | log_dir = kwargs['train'].get('log_dir') 80 | if log_dir is None: 81 | batch_size = kwargs['data'].get('batch_size') 82 | learning_rate = kwargs['train'].get('base_lr') 83 | max_diffusion_step = kwargs['model'].get('max_diffusion_step') 84 | num_rnn_layers = kwargs['model'].get('num_rnn_layers') 85 | rnn_units = kwargs['model'].get('rnn_units') 86 | structure = '-'.join( 87 | ['%d' % rnn_units for _ in range(num_rnn_layers)]) 88 | horizon = kwargs['model'].get('horizon') 89 | filter_type = kwargs['model'].get('filter_type') 90 | filter_type_abbr = 'L' 91 | if filter_type == 'random_walk': 92 | filter_type_abbr = 'R' 93 | elif filter_type == 'dual_random_walk': 94 | filter_type_abbr = 'DR' 95 | run_id = 'GTS_%s_%d_h_%d_%s_lr_%g_bs_%d_%s/' % ( 96 | filter_type_abbr, max_diffusion_step, horizon, 97 | structure, learning_rate, batch_size, 98 | time.strftime('%m%d%H%M%S')) 99 | base_dir = kwargs.get('base_dir') 100 | log_dir = os.path.join(base_dir, run_id) 101 | if not os.path.exists(log_dir): 102 | os.makedirs(log_dir) 103 | return log_dir 104 | 105 | def save_model(self, epoch): 106 | if not os.path.exists('models/'): 107 | os.makedirs('models/') 108 | 109 | config = dict(self._kwargs) 110 | config['model_state_dict'] = self.GTS_model.state_dict() 111 | config['epoch'] = epoch 112 | torch.save(config, 'models/epo%d.tar' % epoch) 113 | self._logger.info("Saved model at 
{}".format(epoch)) 114 | return 'models/epo%d.tar' % epoch 115 | 116 | def load_model(self): 117 | self._setup_graph() 118 | assert os.path.exists('models/epo%d.tar' % self._epoch_num), 'Weights at epoch %d not found' % self._epoch_num 119 | checkpoint = torch.load('models/epo%d.tar' % self._epoch_num, map_location='cpu') 120 | self.GTS_model.load_state_dict(checkpoint['model_state_dict']) 121 | self._logger.info("Loaded model at {}".format(self._epoch_num)) 122 | 123 | def _setup_graph(self): 124 | with torch.no_grad(): 125 | self.GTS_model = self.GTS_model.eval() 126 | 127 | val_iterator = self._data['val_loader'].get_iterator() 128 | 129 | for _, (x, y) in enumerate(val_iterator): 130 | x, y = self._prepare_data(x, y) 131 | output = self.GTS_model(x, self._train_feas) 132 | break 133 | 134 | def train(self, **kwargs): 135 | kwargs.update(self._train_kwargs) 136 | return self._train(**kwargs) 137 | 138 | def evaluate(self,label, dataset='val', batches_seen=0, gumbel_soft=True): 139 | """ 140 | Computes mean L1Loss 141 | :return: mean L1Loss 142 | """ 143 | with torch.no_grad(): 144 | self.GTS_model = self.GTS_model.eval() 145 | 146 | val_iterator = self._data['{}_loader'.format(dataset)].get_iterator() 147 | losses = [] 148 | mapes = [] 149 | #rmses = [] 150 | mses = [] 151 | temp = self.temperature 152 | 153 | l_3 = [] 154 | m_3 = [] 155 | r_3 = [] 156 | l_6 = [] 157 | m_6 = [] 158 | r_6 = [] 159 | l_12 = [] 160 | m_12 = [] 161 | r_12 = [] 162 | 163 | for batch_idx, (x, y) in enumerate(val_iterator): 164 | x, y = self._prepare_data(x, y) 165 | 166 | output, mid_output = self.GTS_model(label, x, self._train_feas, temp, gumbel_soft) 167 | 168 | if label == 'without_regularization': 169 | loss = self._compute_loss(y, output) 170 | y_true = self.standard_scaler.inverse_transform(y) 171 | y_pred = self.standard_scaler.inverse_transform(output) 172 | mapes.append(masked_mape_loss(y_pred, y_true).item()) 173 | mses.append(masked_mse_loss(y_pred, y_true).item()) 174 | #rmses.append(masked_rmse_loss(y_pred, y_true).item()) 175 | losses.append(loss.item()) 176 | 177 | 178 | # Followed the DCRNN TensorFlow Implementation 179 | l_3.append(masked_mae_loss(y_pred[2:3], y_true[2:3]).item()) 180 | m_3.append(masked_mape_loss(y_pred[2:3], y_true[2:3]).item()) 181 | r_3.append(masked_mse_loss(y_pred[2:3], y_true[2:3]).item()) 182 | l_6.append(masked_mae_loss(y_pred[5:6], y_true[5:6]).item()) 183 | m_6.append(masked_mape_loss(y_pred[5:6], y_true[5:6]).item()) 184 | r_6.append(masked_mse_loss(y_pred[5:6], y_true[5:6]).item()) 185 | l_12.append(masked_mae_loss(y_pred[11:12], y_true[11:12]).item()) 186 | m_12.append(masked_mape_loss(y_pred[11:12], y_true[11:12]).item()) 187 | r_12.append(masked_mse_loss(y_pred[11:12], y_true[11:12]).item()) 188 | 189 | 190 | else: 191 | loss_1 = self._compute_loss(y, output) 192 | pred = torch.sigmoid(mid_output.view(mid_output.shape[0] * mid_output.shape[1])) 193 | true_label = self.adj_mx.view(mid_output.shape[0] * mid_output.shape[1]).to(device) 194 | compute_loss = torch.nn.BCELoss() 195 | loss_g = compute_loss(pred, true_label) 196 | loss = loss_1 + loss_g 197 | # option 198 | # loss = loss_1 + 10*loss_g 199 | losses.append((loss_1.item()+loss_g.item())) 200 | 201 | y_true = self.standard_scaler.inverse_transform(y) 202 | y_pred = self.standard_scaler.inverse_transform(output) 203 | mapes.append(masked_mape_loss(y_pred, y_true).item()) 204 | #rmses.append(masked_rmse_loss(y_pred, y_true).item()) 205 | mses.append(masked_mse_loss(y_pred, y_true).item()) 206 | 207 | # 
Followed the DCRNN TensorFlow Implementation 208 | l_3.append(masked_mae_loss(y_pred[2:3], y_true[2:3]).item()) 209 | m_3.append(masked_mape_loss(y_pred[2:3], y_true[2:3]).item()) 210 | r_3.append(masked_mse_loss(y_pred[2:3], y_true[2:3]).item()) 211 | l_6.append(masked_mae_loss(y_pred[5:6], y_true[5:6]).item()) 212 | m_6.append(masked_mape_loss(y_pred[5:6], y_true[5:6]).item()) 213 | r_6.append(masked_mse_loss(y_pred[5:6], y_true[5:6]).item()) 214 | l_12.append(masked_mae_loss(y_pred[11:12], y_true[11:12]).item()) 215 | m_12.append(masked_mape_loss(y_pred[11:12], y_true[11:12]).item()) 216 | r_12.append(masked_mse_loss(y_pred[11:12], y_true[11:12]).item()) 217 | 218 | #if batch_idx % 100 == 1: 219 | # temp = np.maximum(temp * np.exp(-self.ANNEAL_RATE * batch_idx), self.temp_min) 220 | mean_loss = np.mean(losses) 221 | mean_mape = np.mean(mapes) 222 | mean_rmse = np.sqrt(np.mean(mses)) 223 | # mean_rmse = np.mean(rmses) #another option 224 | 225 | if dataset == 'test': 226 | 227 | # Followed the DCRNN PyTorch Implementation 228 | message = 'Test: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(mean_loss, mean_mape, mean_rmse) 229 | self._logger.info(message) 230 | 231 | # Followed the DCRNN TensorFlow Implementation 232 | message = 'Horizon 15mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(np.mean(l_3), np.mean(m_3), 233 | np.sqrt(np.mean(r_3))) 234 | self._logger.info(message) 235 | message = 'Horizon 30mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(np.mean(l_6), np.mean(m_6), 236 | np.sqrt(np.mean(r_6))) 237 | self._logger.info(message) 238 | message = 'Horizon 60mins: mae: {:.4f}, mape: {:.4f}, rmse: {:.4f}'.format(np.mean(l_12), np.mean(m_12), 239 | np.sqrt(np.mean(r_12))) 240 | self._logger.info(message) 241 | 242 | self._writer.add_scalar('{} loss'.format(dataset), mean_loss, batches_seen) 243 | if label == 'without_regularization': 244 | return mean_loss, mean_mape, mean_rmse 245 | else: 246 | return mean_loss 247 | 248 | 249 | def _train(self, base_lr, 250 | steps, patience=200, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=0, 251 | test_every_n_epochs=10, epsilon=1e-8, **kwargs): 252 | # steps is used in learning rate - will see if need to use it? 
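        # 'steps' is the list of epoch milestones handed to MultiStepLR below: at each milestone
        # the learning rate is multiplied by lr_decay_ratio (e.g. a hypothetical steps: [20, 30, 40]
        # with lr_decay_ratio 0.1 cuts the learning rate by 10x three times).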
253 | min_val_loss = float('inf') 254 | wait = 0 255 | if self.opt == 'adam': 256 | optimizer = torch.optim.Adam(self.GTS_model.parameters(), lr=base_lr, eps=epsilon) 257 | elif self.opt == 'sgd': 258 | optimizer = torch.optim.SGD(self.GTS_model.parameters(), lr=base_lr) 259 | else: 260 | optimizer = torch.optim.Adam(self.GTS_model.parameters(), lr=base_lr, eps=epsilon) 261 | 262 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps, gamma=float(lr_decay_ratio)) 263 | 264 | self._logger.info('Start training ...') 265 | 266 | # this will fail if model is loaded with a changed batch_size 267 | num_batches = self._data['train_loader'].num_batch 268 | self._logger.info("num_batches:{}".format(num_batches)) 269 | 270 | batches_seen = num_batches * self._epoch_num 271 | 272 | for epoch_num in range(self._epoch_num, epochs): 273 | print("Num of epoch:",epoch_num) 274 | self.GTS_model = self.GTS_model.train() 275 | train_iterator = self._data['train_loader'].get_iterator() 276 | losses = [] 277 | start_time = time.time() 278 | temp = self.temperature 279 | gumbel_soft = True 280 | 281 | if epoch_num < self.epoch_use_regularization: 282 | label = 'with_regularization' 283 | else: 284 | label = 'without_regularization' 285 | 286 | for batch_idx, (x, y) in enumerate(train_iterator): 287 | optimizer.zero_grad() 288 | x, y = self._prepare_data(x, y) 289 | output, mid_output = self.GTS_model(label, x, self._train_feas, temp, gumbel_soft, y, batches_seen) 290 | if (epoch_num % epochs) == epochs - 1: 291 | output, mid_output = self.GTS_model(label, x, self._train_feas, temp, gumbel_soft, y, batches_seen) 292 | 293 | if batches_seen == 0: 294 | if self.opt == 'adam': 295 | optimizer = torch.optim.Adam(self.GTS_model.parameters(), lr=base_lr, eps=epsilon) 296 | elif self.opt == 'sgd': 297 | optimizer = torch.optim.SGD(self.GTS_model.parameters(), lr=base_lr) 298 | else: 299 | optimizer = torch.optim.Adam(self.GTS_model.parameters(), lr=base_lr, eps=epsilon) 300 | 301 | self.GTS_model.to(device) 302 | 303 | #if batch_idx % 100 == 1: 304 | # temp = np.maximum(temp * np.exp(-self.ANNEAL_RATE * batch_idx), self.temp_min) 305 | 306 | if label == 'without_regularization': # or label == 'predictor': 307 | loss = self._compute_loss(y, output) 308 | losses.append(loss.item()) 309 | else: 310 | loss_1 = self._compute_loss(y, output) 311 | pred = mid_output.view(mid_output.shape[0] * mid_output.shape[1]) 312 | true_label = self.adj_mx.view(mid_output.shape[0] * mid_output.shape[1]).to(device) 313 | compute_loss = torch.nn.BCELoss() 314 | loss_g = compute_loss(pred, true_label) 315 | loss = loss_1 + loss_g 316 | # option 317 | # loss = loss_1 + 10*loss_g 318 | losses.append((loss_1.item()+loss_g.item())) 319 | 320 | self._logger.debug(loss.item()) 321 | batches_seen += 1 322 | loss.backward() 323 | 324 | # gradient clipping - this does it in place 325 | torch.nn.utils.clip_grad_norm_(self.GTS_model.parameters(), self.max_grad_norm) 326 | 327 | optimizer.step() 328 | self._logger.info("epoch complete") 329 | lr_scheduler.step() 330 | self._logger.info("evaluating now!") 331 | end_time = time.time() 332 | 333 | if label == 'without_regularization': 334 | val_loss, val_mape, val_rmse = self.evaluate(label, dataset='val', batches_seen=batches_seen, gumbel_soft=gumbel_soft) 335 | end_time2 = time.time() 336 | self._writer.add_scalar('training loss', 337 | np.mean(losses), 338 | batches_seen) 339 | 340 | if (epoch_num % log_every) == log_every - 1: 341 | message = 'Epoch [{}/{}] ({}) train_mae: 
{:.4f}, val_mae: {:.4f}, val_mape: {:.4f}, val_rmse: {:.4f}, lr: {:.6f}, ' \ 342 | '{:.1f}s, {:.1f}s'.format(epoch_num, epochs, batches_seen, 343 | np.mean(losses), val_loss, val_mape, val_rmse, 344 | lr_scheduler.get_lr()[0], 345 | (end_time - start_time), (end_time2 - start_time)) 346 | self._logger.info(message) 347 | 348 | if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1: 349 | test_loss, test_mape, test_rmse = self.evaluate(label, dataset='test', batches_seen=batches_seen, gumbel_soft=gumbel_soft) 350 | message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, test_mape: {:.4f}, test_rmse: {:.4f}, lr: {:.6f}, ' \ 351 | '{:.1f}s, {:.1f}s'.format(epoch_num, epochs, batches_seen, 352 | np.mean(losses), test_loss, test_mape, test_rmse, 353 | lr_scheduler.get_lr()[0], 354 | (end_time - start_time), (end_time2 - start_time)) 355 | self._logger.info(message) 356 | else: 357 | val_loss = self.evaluate(label, dataset='val', batches_seen=batches_seen, gumbel_soft=gumbel_soft) 358 | 359 | end_time2 = time.time() 360 | 361 | self._writer.add_scalar('training loss', np.mean(losses), batches_seen) 362 | 363 | if (epoch_num % log_every) == log_every - 1: 364 | message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}'.format(epoch_num, epochs, 365 | batches_seen, 366 | np.mean(losses), val_loss) 367 | self._logger.info(message) 368 | if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1: 369 | test_loss = self.evaluate(label, dataset='test', batches_seen=batches_seen, gumbel_soft=gumbel_soft) 370 | message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \ 371 | '{:.1f}s, {:.1f}s'.format(epoch_num, epochs, batches_seen, 372 | np.mean(losses), test_loss, lr_scheduler.get_lr()[0], 373 | (end_time - start_time), (end_time2 - start_time)) 374 | self._logger.info(message) 375 | 376 | if val_loss < min_val_loss: 377 | wait = 0 378 | if save_model: 379 | model_file_name = self.save_model(epoch_num) 380 | self._logger.info( 381 | 'Val loss decrease from {:.4f} to {:.4f}, ' 382 | 'saving to {}'.format(min_val_loss, val_loss, model_file_name)) 383 | min_val_loss = val_loss 384 | 385 | elif val_loss >= min_val_loss: 386 | wait += 1 387 | if wait == patience: 388 | self._logger.warning('Early stopping at epoch: %d' % epoch_num) 389 | break 390 | 391 | def _prepare_data(self, x, y): 392 | x, y = self._get_x_y(x, y) 393 | x, y = self._get_x_y_in_correct_dims(x, y) 394 | return x.to(device), y.to(device) 395 | 396 | def _get_x_y(self, x, y): 397 | """ 398 | :param x: shape (batch_size, seq_len, num_sensor, input_dim) 399 | :param y: shape (batch_size, horizon, num_sensor, input_dim) 400 | :returns x shape (seq_len, batch_size, num_sensor, input_dim) 401 | y shape (horizon, batch_size, num_sensor, input_dim) 402 | """ 403 | x = torch.from_numpy(x).float() 404 | y = torch.from_numpy(y).float() 405 | self._logger.debug("X: {}".format(x.size())) 406 | self._logger.debug("y: {}".format(y.size())) 407 | x = x.permute(1, 0, 2, 3) 408 | y = y.permute(1, 0, 2, 3) 409 | return x, y 410 | 411 | def _get_x_y_in_correct_dims(self, x, y): 412 | """ 413 | :param x: shape (seq_len, batch_size, num_sensor, input_dim) 414 | :param y: shape (horizon, batch_size, num_sensor, input_dim) 415 | :return: x: shape (seq_len, batch_size, num_sensor * input_dim) 416 | y: shape (horizon, batch_size, num_sensor * output_dim) 417 | """ 418 | batch_size = x.size(1) 419 | x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim) 420 | y = y[..., 
:self.output_dim].view(self.horizon, batch_size, 421 | self.num_nodes * self.output_dim) 422 | return x, y 423 | 424 | def _compute_loss(self, y_true, y_predicted): 425 | y_true = self.standard_scaler.inverse_transform(y_true) 426 | y_predicted = self.standard_scaler.inverse_transform(y_predicted) 427 | return masked_mae_loss(y_predicted, y_true) 428 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy>=0.19.0 2 | numpy>=1.12.1 3 | pandas>=0.19.2 4 | pyyaml 5 | statsmodels 6 | wrapt 7 | tensorflow>=1.3.0 8 | torch 9 | tables 10 | future 11 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoshangcs/GTS/8ed45ff1476639f78c382ff09ecca8e60523e7ce/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/eval_baseline_methods.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from statsmodels.tsa.vector_ar.var_model import VAR 6 | 7 | from lib import utils 8 | from lib.metrics import masked_rmse_np, masked_mape_np, masked_mae_np 9 | from lib.utils import StandardScaler 10 | 11 | 12 | def historical_average_predict(df, period=12 * 24 * 7, test_ratio=0.2, null_val=0.): 13 | """ 14 | Calculates the historical average of sensor reading. 15 | :param df: 16 | :param period: default 1 week. 17 | :param test_ratio: 18 | :param null_val: default 0. 19 | :return: 20 | """ 21 | n_sample, n_sensor = df.shape 22 | n_test = int(round(n_sample * test_ratio)) 23 | n_train = n_sample - n_test 24 | y_test = df[-n_test:] 25 | y_predict = pd.DataFrame.copy(y_test) 26 | 27 | for i in range(n_train, min(n_sample, n_train + period)): 28 | inds = [j for j in range(i % period, n_train, period)] 29 | historical = df.iloc[inds, :] 30 | y_predict.iloc[i - n_train, :] = historical[historical != null_val].mean() 31 | # Copy each period. 32 | for i in range(n_train + period, n_sample, period): 33 | size = min(period, n_sample - i) 34 | start = i - n_train 35 | y_predict.iloc[start:start + size, :] = y_predict.iloc[start - period: start + size - period, :].values 36 | return y_predict, y_test 37 | 38 | 39 | def static_predict(df, n_forward, test_ratio=0.2): 40 | """ 41 | Assumes $x^{t+1} = x^{t}$ 42 | :param df: 43 | :param n_forward: 44 | :param test_ratio: 45 | :return: 46 | """ 47 | test_num = int(round(df.shape[0] * test_ratio)) 48 | y_test = df[-test_num:] 49 | y_predict = df.shift(n_forward).iloc[-test_num:] 50 | return y_predict, y_test 51 | 52 | 53 | def var_predict(df, n_forwards=(1, 3), n_lags=4, test_ratio=0.2): 54 | """ 55 | Multivariate time series forecasting using Vector Auto-Regressive Model. 56 | :param df: pandas.DataFrame, index: time, columns: sensor id, content: data. 57 | :param n_forwards: a tuple of horizons. 58 | :param n_lags: the order of the VAR model. 
59 |     :param test_ratio: fraction of the samples held out for testing (default 0.2).
60 |     :return: [list of prediction DataFrames, one per requested horizon], df_test
61 |     """
62 |     n_sample, n_output = df.shape
63 |     n_test = int(round(n_sample * test_ratio))
64 |     n_train = n_sample - n_test
65 |     df_train, df_test = df[:n_train], df[n_train:]
66 | 
67 |     scaler = StandardScaler(mean=df_train.values.mean(), std=df_train.values.std())
68 |     data = scaler.transform(df_train.values)
69 |     var_model = VAR(data)
70 |     var_result = var_model.fit(n_lags)
71 |     max_n_forwards = np.max(n_forwards)
72 |     # Do forecasting.
73 |     result = np.zeros(shape=(len(n_forwards), n_test, n_output))
74 |     start = n_train - n_lags - max_n_forwards + 1
75 |     for input_ind in range(start, n_sample - n_lags):
76 |         prediction = var_result.forecast(scaler.transform(df.values[input_ind: input_ind + n_lags]), max_n_forwards)
77 |         for i, n_forward in enumerate(n_forwards):
78 |             result_ind = input_ind - n_train + n_lags + n_forward - 1
79 |             if 0 <= result_ind < n_test:
80 |                 result[i, result_ind, :] = prediction[n_forward - 1, :]
81 | 
82 |     df_predicts = []
83 |     for i, n_forward in enumerate(n_forwards):
84 |         df_predict = pd.DataFrame(scaler.inverse_transform(result[i]), index=df_test.index, columns=df_test.columns)
85 |         df_predicts.append(df_predict)
86 |     return df_predicts, df_test
87 | 
88 | 
89 | def eval_static(traffic_reading_df):
90 |     logger.info('Static')
91 |     horizons = [1, 3, 6, 12]
92 |     logger.info('\t'.join(['Model', 'Horizon', 'RMSE', 'MAPE', 'MAE']))
93 |     for horizon in horizons:
94 |         y_predict, y_test = static_predict(traffic_reading_df, n_forward=horizon, test_ratio=0.2)
95 |         rmse = masked_rmse_np(preds=y_predict.values, labels=y_test.values, null_val=0)
96 |         mape = masked_mape_np(preds=y_predict.values, labels=y_test.values, null_val=0)
97 |         mae = masked_mae_np(preds=y_predict.values, labels=y_test.values, null_val=0)
98 |         line = 'Static\t%d\t%.2f\t%.2f\t%.2f' % (horizon, rmse, mape * 100, mae)
99 |         logger.info(line)
100 | 
101 | 
102 | def eval_historical_average(traffic_reading_df, period):
103 |     y_predict, y_test = historical_average_predict(traffic_reading_df, period=period, test_ratio=0.2)
104 |     rmse = masked_rmse_np(preds=y_predict.values, labels=y_test.values, null_val=0)
105 |     mape = masked_mape_np(preds=y_predict.values, labels=y_test.values, null_val=0)
106 |     mae = masked_mae_np(preds=y_predict.values, labels=y_test.values, null_val=0)
107 |     logger.info('Historical Average')
108 |     logger.info('\t'.join(['Model', 'Horizon', 'RMSE', 'MAPE', 'MAE']))
109 |     for horizon in [1, 3, 6, 12]:
110 |         line = 'HA\t%d\t%.2f\t%.2f\t%.2f' % (horizon, rmse, mape * 100, mae)
111 |         logger.info(line)
112 | 
113 | 
114 | def eval_var(traffic_reading_df, n_lags=3):
115 |     n_forwards = [1, 3, 6, 12]
116 |     y_predicts, y_test = var_predict(traffic_reading_df, n_forwards=n_forwards, n_lags=n_lags,
117 |                                      test_ratio=0.2)
118 |     logger.info('VAR (lag=%d)' % n_lags)
119 |     logger.info('Model\tHorizon\tRMSE\tMAPE\tMAE')
120 |     for i, horizon in enumerate(n_forwards):
121 |         rmse = masked_rmse_np(preds=y_predicts[i].values, labels=y_test.values, null_val=0)
122 |         mape = masked_mape_np(preds=y_predicts[i].values, labels=y_test.values, null_val=0)
123 |         mae = masked_mae_np(preds=y_predicts[i].values, labels=y_test.values, null_val=0)
124 |         line = 'VAR\t%d\t%.2f\t%.2f\t%.2f' % (horizon, rmse, mape * 100, mae)
125 |         logger.info(line)
126 | 
127 | 
128 | def main(args):
129 |     traffic_reading_df = pd.read_hdf(args.traffic_reading_filename)
130 | 
eval_static(traffic_reading_df) 131 | eval_historical_average(traffic_reading_df, period=7 * 24 * 12) 132 | eval_var(traffic_reading_df, n_lags=3) 133 | 134 | 135 | if __name__ == '__main__': 136 | logger = utils.get_logger('data/model', 'Baseline') 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('--traffic_reading_filename', default="data/metr-la.h5", type=str, 139 | help='Path to the traffic Dataframe.') 140 | args = parser.parse_args() 141 | main(args) 142 | -------------------------------------------------------------------------------- /scripts/gen_adj_mx.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import numpy as np 7 | import pandas as pd 8 | import pickle 9 | 10 | 11 | def get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1): 12 | """ 13 | 14 | :param distance_df: data frame with three columns: [from, to, distance]. 15 | :param sensor_ids: list of sensor ids. 16 | :param normalized_k: entries that become lower than normalized_k after normalization are set to zero for sparsity. 17 | :return: 18 | """ 19 | num_sensors = len(sensor_ids) 20 | dist_mx = np.zeros((num_sensors, num_sensors), dtype=np.float32) 21 | dist_mx[:] = np.inf 22 | # Builds sensor id to index map. 23 | sensor_id_to_ind = {} 24 | for i, sensor_id in enumerate(sensor_ids): 25 | sensor_id_to_ind[sensor_id] = i 26 | 27 | # Fills cells in the matrix with distances. 28 | for row in distance_df.values: 29 | if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind: 30 | continue 31 | dist_mx[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2] 32 | 33 | # Calculates the standard deviation as theta. 34 | distances = dist_mx[~np.isinf(dist_mx)].flatten() 35 | std = distances.std() 36 | adj_mx = np.exp(-np.square(dist_mx / std)) 37 | # Make the adjacent matrix symmetric by taking the max. 38 | # adj_mx = np.maximum.reduce([adj_mx, adj_mx.T]) 39 | 40 | # Sets entries that lower than a threshold, i.e., k, to zero for sparsity. 41 | adj_mx[adj_mx < normalized_k] = 0 42 | return sensor_ids, sensor_id_to_ind, adj_mx 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--sensor_ids_filename', type=str, default='data/sensor_graph/graph_sensor_ids.txt', 48 | help='File containing sensor ids separated by comma.') 49 | parser.add_argument('--distances_filename', type=str, default='data/sensor_graph/distances_la_2012.csv', 50 | help='CSV file containing sensor distances with three columns: [from, to, distance].') 51 | parser.add_argument('--normalized_k', type=float, default=0.1, 52 | help='Entries that become lower than normalized_k after normalization are set to zero for sparsity.') 53 | parser.add_argument('--output_pkl_filename', type=str, default='data/sensor_graph/adj_mat.pkl', 54 | help='Path of the output file.') 55 | args = parser.parse_args() 56 | 57 | with open(args.sensor_ids_filename) as f: 58 | sensor_ids = f.read().strip().split(',') 59 | distance_df = pd.read_csv(args.distances_filename, dtype={'from': 'str', 'to': 'str'}) 60 | _, sensor_id_to_ind, adj_mx = get_adjacency_matrix(distance_df, sensor_ids) 61 | # Save to pickle file. 
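    # The pickle holds [sensor_ids, sensor_id_to_ind, adj_mx] in that order, which is exactly
    # what lib/utils.load_graph_data unpacks, e.g. (path shown is just an example):
    #     sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data('data/sensor_graph/adj_mx.pkl')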
62 | with open(args.output_pkl_filename, 'wb') as f: 63 | pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f, protocol=2) 64 | -------------------------------------------------------------------------------- /scripts/generate_training_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import argparse 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | 11 | 12 | def generate_graph_seq2seq_io_data( 13 | df, x_offsets, y_offsets, add_time_in_day=True, add_day_in_week=False, scaler=None 14 | ): 15 | """ 16 | Generate samples from 17 | :param df: 18 | :param x_offsets: 19 | :param y_offsets: 20 | :param add_time_in_day: 21 | :param add_day_in_week: 22 | :param scaler: 23 | :return: 24 | # x: (epoch_size, input_length, num_nodes, input_dim) 25 | # y: (epoch_size, output_length, num_nodes, output_dim) 26 | """ 27 | 28 | num_samples, num_nodes = df.shape 29 | data = np.expand_dims(df.values, axis=-1) 30 | data_list = [data] 31 | if add_time_in_day: 32 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D") 33 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0)) 34 | data_list.append(time_in_day) 35 | if add_day_in_week: 36 | day_in_week = np.zeros(shape=(num_samples, num_nodes, 7)) 37 | day_in_week[np.arange(num_samples), :, df.index.dayofweek] = 1 38 | data_list.append(day_in_week) 39 | 40 | data = np.concatenate(data_list, axis=-1) 41 | # epoch_len = num_samples + min(x_offsets) - max(y_offsets) 42 | x, y = [], [] 43 | # t is the index of the last observation. 44 | min_t = abs(min(x_offsets)) 45 | max_t = abs(num_samples - abs(max(y_offsets))) # Exclusive 46 | for t in range(min_t, max_t): 47 | x_t = data[t + x_offsets, ...] 48 | y_t = data[t + y_offsets, ...] 49 | x.append(x_t) 50 | y.append(y_t) 51 | x = np.stack(x, axis=0) 52 | y = np.stack(y, axis=0) 53 | return x, y 54 | 55 | 56 | def generate_train_val_test(args): 57 | df = pd.read_hdf(args.traffic_df_filename) 58 | # 0 is the latest observed sample. 59 | x_offsets = np.sort( 60 | # np.concatenate(([-week_size + 1, -day_size + 1], np.arange(-11, 1, 1))) 61 | np.concatenate((np.arange(-11, 1, 1),)) 62 | ) 63 | # Predict the next one hour 64 | y_offsets = np.sort(np.arange(1, 13, 1)) 65 | # x: (num_samples, input_length, num_nodes, input_dim) 66 | # y: (num_samples, output_length, num_nodes, output_dim) 67 | x, y = generate_graph_seq2seq_io_data( 68 | df, 69 | x_offsets=x_offsets, 70 | y_offsets=y_offsets, 71 | add_time_in_day=True, 72 | add_day_in_week=False, 73 | ) 74 | 75 | print("x shape: ", x.shape, ", y shape: ", y.shape) 76 | # Write the data into npz file. 77 | # num_test = 6831, using the last 6831 examples as testing. 78 | # for the rest: 7/8 is used for training, and 1/8 is used for validation. 
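    # In this version the generated samples are actually split 70% train / 10% val / 20% test
    # (num_val is whatever remains after num_train and num_test are rounded off below).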
79 | num_samples = x.shape[0] 80 | num_test = round(num_samples * 0.2) 81 | num_train = round(num_samples * 0.7) 82 | num_val = num_samples - num_test - num_train 83 | 84 | # train 85 | x_train, y_train = x[:num_train], y[:num_train] 86 | # val 87 | x_val, y_val = ( 88 | x[num_train: num_train + num_val], 89 | y[num_train: num_train + num_val], 90 | ) 91 | # test 92 | x_test, y_test = x[-num_test:], y[-num_test:] 93 | 94 | for cat in ["train", "val", "test"]: 95 | _x, _y = locals()["x_" + cat], locals()["y_" + cat] 96 | print(cat, "x: ", _x.shape, "y:", _y.shape) 97 | np.savez_compressed( 98 | os.path.join(args.output_dir, "%s.npz" % cat), 99 | x=_x, 100 | y=_y, 101 | x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]), 102 | y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 103 | ) 104 | 105 | 106 | def main(args): 107 | print("Generating training data") 108 | generate_train_val_test(args) 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument( 114 | "--output_dir", type=str, default="data/", help="Output directory." 115 | ) 116 | parser.add_argument( 117 | "--traffic_df_filename", 118 | type=str, 119 | default="data/metr-la.h5", 120 | help="Raw traffic readings.", 121 | ) 122 | args = parser.parse_args() 123 | main(args) 124 | -------------------------------------------------------------------------------- /scripts/generate_visualization_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import argparse 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | 11 | 12 | def generate_graph_seq2seq_io_data( 13 | df, x_offsets, y_offsets, add_time_in_day=True, add_day_in_week=False, scaler=None 14 | ): 15 | """ 16 | Generate samples from 17 | :param df: 18 | :param x_offsets: 19 | :param y_offsets: 20 | :param add_time_in_day: 21 | :param add_day_in_week: 22 | :param scaler: 23 | :return: 24 | # x: (epoch_size, input_length, num_nodes, input_dim) 25 | # y: (epoch_size, output_length, num_nodes, output_dim) 26 | """ 27 | 28 | num_samples, num_nodes = df.shape 29 | data = np.expand_dims(df.values, axis=-1) 30 | data_list = [data] 31 | if add_time_in_day: 32 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D") 33 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0)) 34 | data_list.append(time_in_day) 35 | if add_day_in_week: 36 | day_in_week = np.zeros(shape=(num_samples, num_nodes, 7)) 37 | day_in_week[np.arange(num_samples), :, df.index.dayofweek] = 1 38 | data_list.append(day_in_week) 39 | 40 | data = np.concatenate(data_list, axis=-1) 41 | index = np.array(df.index) 42 | # epoch_len = num_samples + min(x_offsets) - max(y_offsets) 43 | x, y = [], [] 44 | indexes = [] 45 | # t is the index of the last observation. 46 | min_t = abs(min(x_offsets)) 47 | max_t = abs(num_samples - abs(max(y_offsets))) # Exclusive 48 | for t in range(min_t, max_t, 12): 49 | x_t = data[t + x_offsets, ...] 50 | y_t = data[t + y_offsets, ...] 
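            # Unlike generate_training_data.py, windows are taken every 12 steps (the stride in
            # the range above), and the timestamps of each target window are collected in
            # `indexes` so that predictions can later be aligned with wall-clock time.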
51 | x.append(x_t) 52 | y.append(y_t) 53 | ind = index[t + y_offsets] 54 | indexes.append(ind) 55 | x = np.stack(x, axis=0) 56 | y = np.stack(y, axis=0) 57 | indexes = np.stack(indexes, axis=0) 58 | return x, y, indexes 59 | 60 | 61 | def generate_train_val_test(args): 62 | df = pd.read_hdf(args.traffic_df_filename) 63 | # 0 is the latest observed sample. 64 | x_offsets = np.sort( 65 | # np.concatenate(([-week_size + 1, -day_size + 1], np.arange(-11, 1, 1))) 66 | np.concatenate((np.arange(-11, 1, 1),)) 67 | ) 68 | # Predict the next one hour 69 | y_offsets = np.sort(np.arange(1, 13, 1)) 70 | # x: (num_samples, input_length, num_nodes, input_dim) 71 | # y: (num_samples, output_length, num_nodes, output_dim) 72 | x, y, index = generate_graph_seq2seq_io_data( 73 | df, 74 | x_offsets=x_offsets, 75 | y_offsets=y_offsets, 76 | add_time_in_day=True, 77 | add_day_in_week=False, 78 | ) 79 | time = df.index 80 | 81 | print("x shape: ", x.shape, ", y shape: ", y.shape) 82 | # Write the data into npz file. 83 | # num_test = 6831, using the last 6831 examples as testing. 84 | # for the rest: 7/8 is used for training, and 1/8 is used for validation. 85 | num_samples = time.shape[0] 86 | num_test = round(num_samples * 0.2) 87 | num_train = round(num_samples * 0.7) 88 | num_val = num_samples - num_test - num_train 89 | 90 | # train 91 | x_train, y_train, time_train = x[:num_train], y[:num_train], index[:num_train] 92 | # val 93 | x_val, y_val, time_val = ( 94 | x[num_train: num_train + num_val], 95 | y[num_train: num_train + num_val], 96 | index[num_train: num_train + num_val], 97 | ) 98 | # test 99 | x_test, y_test, time_test = x[-num_test:], y[-num_test:], index[-num_test:] 100 | 101 | for cat in ["train", "val", "test"]: 102 | _x, _y, _time = locals()["x_" + cat], locals()["y_" + cat], locals()["time_" + cat] 103 | print(cat, "x: ", _x.shape, "y:", _y.shape) 104 | np.savez_compressed( 105 | os.path.join(args.output_dir, "%s.npz" % cat), 106 | x=_x, 107 | y=_y, 108 | time = _time, 109 | x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]), 110 | y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 111 | time_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 112 | ) 113 | 114 | 115 | def main(args): 116 | print("Generating training data") 117 | generate_train_val_test(args) 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument( 123 | "--output_dir", type=str, default="../data/la_time", help="Output directory." 
124 |     )
125 |     parser.add_argument(
126 |         "--traffic_df_filename",
127 |         type=str,
128 |         default="../data/metr-la.h5",
129 |         help="Raw traffic readings.",
130 |     )
131 |     args = parser.parse_args()
132 |     main(args)
133 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import argparse
6 | import os
7 | import yaml
8 | from model.pytorch.supervisor import GTSSupervisor
9 | from lib.utils import load_graph_data
10 | 
11 | def main(args):
12 |     with open(args.config_filename) as f:
13 |         # PyYAML >= 5.1 expects an explicit Loader; a bare yaml.load(f) errors on PyYAML 6.
14 |         supervisor_config = yaml.load(f, Loader=yaml.FullLoader)
15 |     # Derive the run name from the config file name (e.g. 'para_la' from 'data/model/para_la.yaml')
16 |     # rather than from fixed character offsets into the path.
17 |     save_adj_name = os.path.splitext(os.path.basename(args.config_filename))[0]
18 |     supervisor = GTSSupervisor(save_adj_name, temperature=args.temperature, **supervisor_config)
19 |     supervisor.train()
20 | 
21 | if __name__ == '__main__':
22 |     parser = argparse.ArgumentParser()
23 |     parser.add_argument('--config_filename', default='data/model/para_la.yaml', type=str,
24 |                         help='Configuration filename for the model.')
25 |     # Boolean flag: argparse's type=bool would treat any non-empty string, including "False", as True.
26 |     parser.add_argument('--use_cpu_only', action='store_true', help='Set to only use the CPU.')
27 |     parser.add_argument('--temperature', default=0.5, type=float, help='Temperature value for Gumbel-softmax.')
28 |     args = parser.parse_args()
29 |     main(args)
--------------------------------------------------------------------------------
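Not part of the repository: a minimal sketch of driving the supervisor programmatically (e.g. from a notebook), mirroring the calls train.py makes above. It assumes the zipped METR-LA data has been extracted to data/metr-la.h5 and that the train/val/test npz files produced by scripts/generate_training_data.py already exist in the directory named by the data section of data/model/para_la.yaml; the run name 'para_la' and the temperature 0.5 simply mirror train.py's defaults.

    import yaml
    from model.pytorch.supervisor import GTSSupervisor

    with open('data/model/para_la.yaml') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)  # same Loader choice as train.py

    # GTSSupervisor(save_adj_name, temperature, **config) -- the same call train.py makes.
    supervisor = GTSSupervisor('para_la', temperature=0.5, **config)
    supervisor.train()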