├── .gitignore ├── DCRNN_CPU ├── LICENSE ├── README.md ├── data ├── dcrnn_predictions_pytorch.npz ├── model │ ├── dcrnn_bay.yaml │ ├── dcrnn_la.yaml │ ├── dcrnn_test_config.yaml │ └── pretrained │ │ ├── METR-LA │ │ ├── config.yaml │ │ ├── models-2.7422-24375.data-00000-of-00001 │ │ └── models-2.7422-24375.index │ │ └── PEMS-BAY │ │ ├── config.yaml │ │ ├── events.out.tfevents.1547170277.kakarot │ │ ├── models-1.6139-30780.data-00000-of-00001 │ │ └── models-1.6139-30780.index └── sensor_graph │ ├── adj_mx.pkl │ ├── adj_mx_bay.pkl │ ├── distances_la_2012.csv │ ├── graph_sensor_ids.txt │ └── graph_sensor_locations.csv ├── dcrnn_train.py ├── dcrnn_train_pytorch.py ├── figures ├── model_architecture.jpg ├── result1.png ├── result2.png ├── result3.png └── result4.png ├── lib ├── AMSGrad.py ├── __init__.py ├── metrics.py ├── metrics_test.py └── utils.py ├── model ├── __init__.py ├── pytorch │ ├── __init__.py │ ├── dcrnn_cell.py │ ├── dcrnn_model.py │ ├── dcrnn_supervisor.py │ └── loss.py └── tf │ ├── __init__.py │ ├── dcrnn_cell.py │ ├── dcrnn_model.py │ └── dcrnn_supervisor.py ├── requirements.txt ├── run_demo.py ├── run_demo_pytorch.py └── scripts ├── __init__.py ├── eval_baseline_methods.py ├── gen_adj_mx.py └── generate_training_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # dotenv 82 | .env 83 | 84 | # virtualenv 85 | .venv 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | .spyproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | # mkdocs documentation 97 | /site 98 | 99 | # mypy 100 | .mypy_cache/ 101 | 102 | # pycharm 103 | .idea/ 104 | -------------------------------------------------------------------------------- /DCRNN_CPU: -------------------------------------------------------------------------------- 1 | FROM ufoym/deepo:cpu 2 | COPY requirements.txt . 
3 | RUN pip install -r requirements.txt 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Yaguang Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting 2 | 3 | ![Diffusion Convolutional Recurrent Neural Network](figures/model_architecture.jpg "Model Architecture") 4 | 5 | This is a PyTorch implementation of Diffusion Convolutional Recurrent Neural Network in the following paper: \ 6 | Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926), ICLR 2018. 7 | 8 | 9 | ## Requirements 10 | * torch 11 | * scipy>=0.19.0 12 | * numpy>=1.12.1 13 | * pandas>=0.19.2 14 | * pyyaml 15 | * statsmodels 16 | * tensorflow>=1.3.0 17 | * torch 18 | * tables 19 | * future 20 | 21 | Dependencies can be installed using the following command: 22 | ```bash 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ### Comparison with Tensorflow implementation 27 | 28 | In MAE (For LA dataset, PEMS-BAY coming in a while) 29 | 30 | | Horizon | Tensorflow | Pytorch | 31 | |:--------|:--------:|:--------:| 32 | | 1 Hour | 3.69 | 3.12 | 33 | | 30 Min | 3.15 | 2.82 | 34 | | 15 Min | 2.77 | 2.56 | 35 | 36 | 37 | ## Data Preparation 38 | The traffic data files for Los Angeles (METR-LA) and the Bay Area (PEMS-BAY), i.e., `metr-la.h5` and `pems-bay.h5`, are available at [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](https://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g), and should be 39 | put into the `data/` folder. 40 | The `*.h5` files store the data in `pandas.DataFrame` using the `HDF5` file format. Here is an example: 41 | 42 | | | sensor_0 | sensor_1 | sensor_2 | sensor_n | 43 | |:-------------------:|:--------:|:--------:|:--------:|:--------:| 44 | | 2018/01/01 00:00:00 | 60.0 | 65.0 | 70.0 | ... 
| 45 | | 2018/01/01 00:05:00 | 61.0 | 64.0 | 65.0 | ... | 46 | | 2018/01/01 00:10:00 | 63.0 | 65.0 | 60.0 | ... | 47 | | ... | ... | ... | ... | ... | 48 | 49 | 50 | Here is an article about [Using HDF5 with Python](https://medium.com/@jerilkuriakose/using-hdf5-with-python-6c5242d08773). 51 | 52 | Run the following commands to generate the train/test/val datasets at `data/{METR-LA,PEMS-BAY}/{train,val,test}.npz`. 53 | ```bash 54 | # Create data directories 55 | mkdir -p data/{METR-LA,PEMS-BAY} 56 | 57 | # METR-LA 58 | python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5 59 | 60 | # PEMS-BAY 61 | python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5 62 | ``` 63 | 64 | ## Graph Construction 65 | As the current implementation is based on pre-calculated road network distances between sensors, it currently only 66 | supports sensor ids in Los Angeles (see `data/sensor_graph/sensor_info_201206.csv`). 67 | ```bash 68 | python -m scripts.gen_adj_mx --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\ 69 | --output_pkl_filename=data/sensor_graph/adj_mx.pkl 70 | ``` 71 | Besides, the locations of sensors in Los Angeles, i.e., METR-LA, are available at [data/sensor_graph/graph_sensor_locations.csv](https://github.com/liyaguang/DCRNN/blob/master/data/sensor_graph/graph_sensor_locations.csv). 72 | 73 | ## Run the Pre-trained Model on METR-LA 74 | 75 | ```bash 76 | # METR-LA 77 | python run_demo_pytorch.py --config_filename=data/model/pretrained/METR-LA/config.yaml 78 | 79 | # PEMS-BAY 80 | python run_demo_pytorch.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml 81 | ``` 82 | The generated prediction of DCRNN is in `data/results/dcrnn_predictions`. 
83 | 84 | 85 | ## Model Training 86 | ```bash 87 | # METR-LA 88 | python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_la.yaml 89 | 90 | # PEMS-BAY 91 | python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_bay.yaml 92 | ``` 93 | 94 | There is a chance that the training loss will explode, the temporary workaround is to restart from the last saved model before the explosion, or to decrease the learning rate earlier in the learning rate schedule. 95 | 96 | 97 | ## Eval baseline methods 98 | ```bash 99 | # METR-LA 100 | python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5 101 | ``` 102 | 103 | ### PyTorch Results 104 | 105 | ![PyTorch Results](figures/result1.png "PyTorch Results") 106 | 107 | ![PyTorch Results](figures/result2.png "PyTorch Results") 108 | 109 | ![PyTorch Results](figures/result3.png "PyTorch Results") 110 | 111 | ![PyTorch Results](figures/result4.png "PyTorch Results") 112 | 113 | ## Citation 114 | 115 | If you find this repository, e.g., the code and the datasets, useful in your research, please cite the following paper: 116 | ``` 117 | @inproceedings{li2018dcrnn_traffic, 118 | title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting}, 119 | author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan}, 120 | booktitle={International Conference on Learning Representations (ICLR '18)}, 121 | year={2018} 122 | } 123 | ``` 124 | -------------------------------------------------------------------------------- /data/dcrnn_predictions_pytorch.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/data/dcrnn_predictions_pytorch.npz -------------------------------------------------------------------------------- /data/model/dcrnn_bay.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | base_dir: 
data/model 3 | log_level: INFO 4 | data: 5 | batch_size: 64 6 | dataset_dir: data/PEMS-BAY 7 | test_batch_size: 64 8 | val_batch_size: 64 9 | graph_pkl_filename: data/sensor_graph/adj_mx_bay.pkl 10 | 11 | model: 12 | cl_decay_steps: 2000 13 | filter_type: dual_random_walk 14 | horizon: 12 15 | input_dim: 2 16 | l1_decay: 0 17 | max_diffusion_step: 2 18 | num_nodes: 325 19 | num_rnn_layers: 2 20 | output_dim: 1 21 | rnn_units: 64 22 | seq_len: 12 23 | use_curriculum_learning: true 24 | 25 | train: 26 | base_lr: 0.01 27 | dropout: 0 28 | epoch: 0 29 | epochs: 100 30 | epsilon: 1.0e-3 31 | global_step: 0 32 | lr_decay_ratio: 0.1 33 | max_grad_norm: 5 34 | max_to_keep: 100 35 | min_learning_rate: 2.0e-06 36 | optimizer: adam 37 | patience: 50 38 | steps: [20, 30, 40, 50] 39 | test_every_n_epochs: 10 40 | -------------------------------------------------------------------------------- /data/model/dcrnn_la.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | base_dir: data/model 3 | log_level: INFO 4 | data: 5 | batch_size: 64 6 | dataset_dir: data/METR-LA 7 | test_batch_size: 64 8 | val_batch_size: 64 9 | graph_pkl_filename: data/sensor_graph/adj_mx.pkl 10 | 11 | model: 12 | cl_decay_steps: 2000 13 | filter_type: dual_random_walk 14 | horizon: 12 15 | input_dim: 2 16 | l1_decay: 0 17 | max_diffusion_step: 2 18 | num_nodes: 207 19 | num_rnn_layers: 2 20 | output_dim: 1 21 | rnn_units: 64 22 | seq_len: 12 23 | use_curriculum_learning: true 24 | 25 | train: 26 | base_lr: 0.01 27 | dropout: 0 28 | epoch: 0 29 | epochs: 100 30 | epsilon: 1.0e-3 31 | global_step: 0 32 | lr_decay_ratio: 0.1 33 | max_grad_norm: 5 34 | max_to_keep: 100 35 | min_learning_rate: 2.0e-06 36 | optimizer: adam 37 | patience: 50 38 | steps: [20, 30, 40, 50] 39 | test_every_n_epochs: 10 -------------------------------------------------------------------------------- /data/model/dcrnn_test_config.yaml: 
-------------------------------------------------------------------------------- 1 | --- 2 | base_dir: data/model 3 | log_level: INFO 4 | data: 5 | batch_size: 64 6 | dataset_dir: data/METR-LA 7 | test_batch_size: 64 8 | val_batch_size: 64 9 | graph_pkl_filename: data/sensor_graph/adj_mx.pkl 10 | 11 | model: 12 | cl_decay_steps: 2000 13 | filter_type: dual_random_walk 14 | horizon: 12 15 | input_dim: 2 16 | l1_decay: 0 17 | max_diffusion_step: 2 18 | num_nodes: 207 19 | num_rnn_layers: 2 20 | output_dim: 1 21 | rnn_units: 64 22 | seq_len: 12 23 | use_curriculum_learning: true 24 | 25 | train: 26 | base_lr: 0.01 27 | dropout: 0 28 | epoch: 51 29 | epochs: 100 30 | epsilon: 1.0e-3 31 | global_step: 0 32 | lr_decay_ratio: 0.1 33 | max_grad_norm: 5 34 | max_to_keep: 100 35 | min_learning_rate: 2.0e-06 36 | optimizer: adam 37 | patience: 50 38 | steps: [20, 30, 40, 50] 39 | test_every_n_epochs: 10 -------------------------------------------------------------------------------- /data/model/pretrained/METR-LA/config.yaml: -------------------------------------------------------------------------------- 1 | base_dir: data/model 2 | log_level: INFO 3 | data: 4 | batch_size: 64 5 | dataset_dir: data/METR-LA 6 | graph_pkl_filename: data/sensor_graph/adj_mx.pkl 7 | test_batch_size: 64 8 | model: 9 | cl_decay_steps: 2000 10 | filter_type: dual_random_walk 11 | horizon: 12 12 | input_dim: 2 13 | l1_decay: 0 14 | max_diffusion_step: 2 15 | num_nodes: 207 16 | num_rnn_layers: 2 17 | output_dim: 1 18 | rnn_units: 64 19 | seq_len: 12 20 | use_curriculum_learning: true 21 | train: 22 | base_lr: 0.01 23 | dropout: 0 24 | epoch: 64 25 | epochs: 100 26 | epsilon: 0.001 27 | global_step: 24375 28 | log_dir: data/model/pretrained/METR-LA 29 | lr_decay_ratio: 0.1 30 | max_grad_norm: 5 31 | max_to_keep: 100 32 | min_learning_rate: 2.0e-06 33 | model_filename: data/model/pretrained/METR-LA/models-2.7422-24375 34 | optimizer: adam 35 | patience: 50 36 | steps: 37 | - 20 38 | - 30 39 | - 40 40 
| - 50 41 | test_every_n_epochs: 10 42 | -------------------------------------------------------------------------------- /data/model/pretrained/METR-LA/models-2.7422-24375.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/data/model/pretrained/METR-LA/models-2.7422-24375.data-00000-of-00001 -------------------------------------------------------------------------------- /data/model/pretrained/METR-LA/models-2.7422-24375.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/data/model/pretrained/METR-LA/models-2.7422-24375.index -------------------------------------------------------------------------------- /data/model/pretrained/PEMS-BAY/config.yaml: -------------------------------------------------------------------------------- 1 | base_dir: data/model 2 | data: 3 | batch_size: 64 4 | dataset_dir: data/PEMS-BAY 5 | graph_pkl_filename: data/sensor_graph/adj_mx_bay.pkl 6 | test_batch_size: 64 7 | val_batch_size: 64 8 | log_level: INFO 9 | model: 10 | cl_decay_steps: 2000 11 | filter_type: dual_random_walk 12 | horizon: 12 13 | input_dim: 2 14 | l1_decay: 0 15 | max_diffusion_step: 2 16 | num_nodes: 325 17 | num_rnn_layers: 2 18 | output_dim: 1 19 | rnn_units: 64 20 | seq_len: 12 21 | use_curriculum_learning: true 22 | train: 23 | base_lr: 0.01 24 | dropout: 0 25 | epoch: 53 26 | epochs: 100 27 | epsilon: 0.001 28 | global_step: 30780 29 | log_dir: data/model/pretrained/PEMS-BAY/ 30 | lr_decay_ratio: 0.1 31 | max_grad_norm: 5 32 | max_to_keep: 100 33 | min_learning_rate: 2.0e-06 34 | model_filename: data/model/pretrained/PEMS-BAY/models-1.6139-30780 35 | optimizer: adam 36 | patience: 50 37 | steps: 38 | - 20 39 | - 30 40 | - 40 41 | - 50 42 | test_every_n_epochs: 10 43 | 
-------------------------------------------------------------------------------- /data/model/pretrained/PEMS-BAY/events.out.tfevents.1547170277.kakarot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/data/model/pretrained/PEMS-BAY/events.out.tfevents.1547170277.kakarot -------------------------------------------------------------------------------- /data/model/pretrained/PEMS-BAY/models-1.6139-30780.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/data/model/pretrained/PEMS-BAY/models-1.6139-30780.data-00000-of-00001 -------------------------------------------------------------------------------- /data/model/pretrained/PEMS-BAY/models-1.6139-30780.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/data/model/pretrained/PEMS-BAY/models-1.6139-30780.index -------------------------------------------------------------------------------- /data/sensor_graph/graph_sensor_ids.txt: -------------------------------------------------------------------------------- 1 | 
773869,767541,767542,717447,717446,717445,773062,767620,737529,717816,765604,767471,716339,773906,765273,716331,771667,716337,769953,769402,769403,769819,769405,716941,717578,716960,717804,767572,767573,773012,773013,764424,769388,716328,717819,769941,760987,718204,718045,769418,768066,772140,773927,760024,774012,774011,767609,769359,760650,716956,769831,761604,717495,716554,773953,767470,716955,764949,773954,767366,769444,773939,774067,769443,767750,767751,767610,773880,764766,717497,717490,717491,717492,717493,765176,717498,717499,765171,718064,718066,765164,769431,769430,717610,767053,767621,772596,772597,767350,767351,716571,773023,767585,773024,717483,718379,717481,717480,717486,764120,772151,718371,717489,717488,717818,718076,718072,767455,767454,761599,717099,773916,716968,769467,717576,717573,717572,717571,717570,764760,718089,769847,717608,767523,716942,718090,769867,717472,717473,759591,764781,765099,762329,716953,716951,767509,765182,769358,772513,716958,718496,769346,773904,718499,764853,761003,717502,759602,717504,763995,717508,765265,773996,773995,717469,717468,764106,717465,764794,717466,717461,717460,717463,717462,769345,716943,772669,717582,717583,717580,716949,717587,772178,717585,716939,768469,764101,767554,773975,773974,717510,717513,717825,767495,767494,717821,717823,717458,717459,769926,764858,717450,717452,717453,759772,717456,771673,772167,769372,774204,769806,717590,717592,717595,772168,718141,769373 -------------------------------------------------------------------------------- /data/sensor_graph/graph_sensor_locations.csv: -------------------------------------------------------------------------------- 1 | index,sensor_id,latitude,longitude 2 | 0,773869,34.15497,-118.31829 3 | 1,767541,34.11621,-118.23799 4 | 2,767542,34.11641,-118.23819 5 | 3,717447,34.07248,-118.26772 6 | 4,717446,34.07142,-118.26572 7 | 5,717445,34.06913,-118.25932 8 | 6,773062,34.05368,-118.23369 9 | 7,767620,34.13486,-118.22932 10 | 8,737529,34.20264,-118.47352 11 | 
9,717816,34.15562,-118.46860 12 | 10,765604,34.16415,-118.38223 13 | 11,767471,34.15691,-118.22469 14 | 12,716339,34.07821,-118.28795 15 | 13,773906,34.15660,-118.30266 16 | 14,765273,34.18949,-118.47437 17 | 15,716331,34.07006,-118.26246 18 | 16,771667,34.07314,-118.23388 19 | 17,716337,34.07732,-118.28186 20 | 18,769953,34.20672,-118.19992 21 | 19,769402,34.12095,-118.33911 22 | 20,769403,34.12073,-118.33928 23 | 21,769819,34.20584,-118.19803 24 | 22,769405,34.12634,-118.34482 25 | 23,716941,34.05767,-118.21435 26 | 24,717578,34.15478,-118.27076 27 | 25,716960,34.12121,-118.27164 28 | 26,717804,34.09478,-118.47605 29 | 27,767572,34.12967,-118.22871 30 | 28,767573,34.12964,-118.22901 31 | 29,773012,34.08390,-118.22086 32 | 30,773013,34.08374,-118.22076 33 | 31,764424,34.17878,-118.39469 34 | 32,769388,34.11027,-118.33441 35 | 33,716328,34.06664,-118.25397 36 | 34,717819,34.18784,-118.47407 37 | 35,769941,34.20699,-118.20237 38 | 36,760987,34.15359,-118.34043 39 | 37,718204,34.15541,-118.29575 40 | 38,718045,34.06712,-118.23973 41 | 39,769418,34.12659,-118.34465 42 | 40,768066,34.15118,-118.37480 43 | 41,772140,34.16498,-118.47493 44 | 42,773927,34.15262,-118.28034 45 | 43,760024,34.15930,-118.46483 46 | 44,774012,34.14769,-118.20137 47 | 45,774011,34.14747,-118.20123 48 | 46,767609,34.18555,-118.21733 49 | 47,769359,34.15660,-118.42216 50 | 48,760650,34.07505,-118.23256 51 | 49,716956,34.10658,-118.25544 52 | 50,769831,34.20663,-118.20101 53 | 51,761604,34.15407,-118.28711 54 | 52,717495,34.15664,-118.41326 55 | 53,716554,34.15597,-118.26660 56 | 54,773953,34.15522,-118.29344 57 | 55,767470,34.15699,-118.22436 58 | 56,716955,34.09579,-118.24427 59 | 57,764949,34.12881,-118.34684 60 | 58,773954,34.15544,-118.29344 61 | 59,767366,34.21216,-118.47341 62 | 60,769444,34.15555,-118.43908 63 | 61,773939,34.15297,-118.37226 64 | 62,774067,34.15362,-118.28441 65 | 63,769443,34.15574,-118.43931 66 | 64,767750,34.09335,-118.20635 67 | 65,767751,34.09335,-118.20616 68 | 
66,767610,34.18555,-118.21766 69 | 67,773880,34.15367,-118.34840 70 | 68,764766,34.13338,-118.35350 71 | 69,717497,34.15685,-118.41456 72 | 70,717490,34.14745,-118.37124 73 | 71,717491,34.14761,-118.37110 74 | 72,717492,34.15459,-118.37935 75 | 73,717493,34.15434,-118.39618 76 | 74,765176,34.13286,-118.35135 77 | 75,717498,34.15571,-118.43273 78 | 76,717499,34.15666,-118.44808 79 | 77,765171,34.16789,-118.46896 80 | 78,718064,34.11296,-118.24489 81 | 79,718066,34.12302,-118.22889 82 | 80,765164,34.07898,-118.28911 83 | 81,769431,34.15843,-118.45664 84 | 82,769430,34.15818,-118.45658 85 | 83,717610,34.17886,-118.39497 86 | 84,767053,34.17091,-118.46775 87 | 85,767621,34.13486,-118.22969 88 | 86,772596,34.17126,-118.50495 89 | 87,772597,34.17109,-118.50495 90 | 88,767350,34.18011,-118.47045 91 | 89,767351,34.18022,-118.47022 92 | 90,716571,34.20000,-118.40337 93 | 91,773023,34.05773,-118.24348 94 | 92,767585,34.16556,-118.22432 95 | 93,773024,34.05759,-118.24357 96 | 94,717483,34.11684,-118.33698 97 | 95,718379,34.14224,-118.27812 98 | 96,717481,34.10634,-118.32826 99 | 97,717480,34.10478,-118.32497 100 | 98,717486,34.12974,-118.34809 101 | 99,764120,34.20164,-118.40366 102 | 100,772151,34.16928,-118.49872 103 | 101,718371,34.09017,-118.23849 104 | 102,717489,34.13876,-118.36438 105 | 103,717488,34.13561,-118.36006 106 | 104,717818,34.17220,-118.46753 107 | 105,718076,34.16339,-118.22530 108 | 106,718072,34.14910,-118.22570 109 | 107,767455,34.14347,-118.22704 110 | 108,767454,34.14352,-118.22733 111 | 109,761599,34.14226,-118.27786 112 | 110,717099,34.15648,-118.24674 113 | 111,773916,34.15247,-118.28520 114 | 112,716968,34.16588,-118.29809 115 | 113,769467,34.15451,-118.39699 116 | 114,717576,34.15559,-118.29570 117 | 115,717573,34.15384,-118.32500 118 | 116,717572,34.15351,-118.32751 119 | 117,717571,34.15326,-118.35921 120 | 118,717570,34.15302,-118.35921 121 | 119,764760,34.13401,-118.35506 122 | 120,718089,34.12769,-118.27372 123 | 
121,769847,34.20983,-118.22351 124 | 122,717608,34.17154,-118.38812 125 | 123,767523,34.11439,-118.24209 126 | 124,716942,34.05930,-118.21451 127 | 125,718090,34.14847,-118.27969 128 | 126,769867,34.21846,-118.23931 129 | 127,717472,34.10045,-118.31601 130 | 128,717473,34.10054,-118.31581 131 | 129,759591,34.11521,-118.26825 132 | 130,764781,34.16037,-118.47012 133 | 131,765099,34.16378,-118.47224 134 | 132,762329,34.13904,-118.22862 135 | 133,716953,34.09458,-118.24279 136 | 134,716951,34.08581,-118.23182 137 | 135,767509,34.11059,-118.24819 138 | 136,765182,34.06491,-118.25126 139 | 137,769358,34.15679,-118.42222 140 | 138,772513,34.06871,-118.23661 141 | 139,716958,34.11167,-118.26501 142 | 140,718496,34.15403,-118.34232 143 | 141,769346,34.15677,-118.40424 144 | 142,773904,34.15641,-118.30266 145 | 143,718499,34.15469,-118.31253 146 | 144,764853,34.06461,-118.25102 147 | 145,761003,34.15546,-118.30841 148 | 146,717502,34.16521,-118.47484 149 | 147,759602,34.12199,-118.27178 150 | 148,717504,34.16519,-118.49166 151 | 149,763995,34.21979,-118.40931 152 | 150,717508,34.17112,-118.51814 153 | 151,765265,34.18529,-118.47395 154 | 152,773996,34.14511,-118.21587 155 | 153,773995,34.14483,-118.21587 156 | 154,717469,34.09710,-118.31366 157 | 155,717468,34.09699,-118.31381 158 | 156,764106,34.17169,-118.38801 159 | 157,717465,34.09373,-118.30907 160 | 158,764794,34.16028,-118.46808 161 | 159,717466,34.09359,-118.30918 162 | 160,717461,34.08558,-118.30174 163 | 161,717460,34.08571,-118.30161 164 | 162,717463,34.09004,-118.30590 165 | 163,717462,34.08993,-118.30607 166 | 164,769345,34.15655,-118.40441 167 | 165,716943,34.05987,-118.21492 168 | 166,772669,34.07828,-118.22834 169 | 167,717582,34.15646,-118.26092 170 | 168,717583,34.15627,-118.25506 171 | 169,717580,34.15620,-118.26359 172 | 170,716949,34.08406,-118.22974 173 | 171,717587,34.15402,-118.23893 174 | 172,772178,34.16903,-118.49885 175 | 173,717585,34.15564,-118.24188 176 | 174,716939,34.04301,-118.21724 177 | 
175,768469,34.13583,-118.35993 178 | 176,764101,34.16421,-118.38246 179 | 177,767554,34.11966,-118.23143 180 | 178,773975,34.14584,-118.22251 181 | 179,773974,34.14559,-118.22251 182 | 180,717510,34.17128,-118.51976 183 | 181,717513,34.17339,-118.53680 184 | 182,717825,34.22164,-118.47307 185 | 183,767495,34.10377,-118.24992 186 | 184,767494,34.10377,-118.24962 187 | 185,717821,34.20112,-118.47361 188 | 186,717823,34.20264,-118.47326 189 | 187,717458,34.08265,-118.29755 190 | 188,717459,34.08294,-118.29729 191 | 189,769926,34.21356,-118.23113 192 | 190,764858,34.15270,-118.37540 193 | 191,717450,34.07488,-118.27362 194 | 192,717452,34.07502,-118.27356 195 | 193,717453,34.07696,-118.28093 196 | 194,759772,34.17115,-118.30539 197 | 195,717456,34.08102,-118.29325 198 | 196,771673,34.07787,-118.22871 199 | 197,772167,34.16526,-118.47985 200 | 198,769372,34.10270,-118.31730 201 | 199,774204,34.15397,-118.34172 202 | 200,769806,34.19638,-118.18442 203 | 201,717590,34.14929,-118.23182 204 | 202,717592,34.14604,-118.22430 205 | 203,717595,34.14163,-118.18290 206 | 204,772168,34.16542,-118.47985 207 | 205,718141,34.15133,-118.37456 208 | 206,769373,34.10262,-118.31747 -------------------------------------------------------------------------------- /dcrnn_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import tensorflow as tf 7 | import yaml 8 | 9 | from lib.utils import load_graph_data 10 | from model.tf.dcrnn_supervisor import DCRNNSupervisor 11 | 12 | 13 | def main(args): 14 | with open(args.config_filename) as f: 15 | supervisor_config = yaml.load(f) 16 | 17 | graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') 18 | sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename) 19 | 20 | tf_config = tf.ConfigProto() 21 | if args.use_cpu_only: 22 | 
tf_config = tf.ConfigProto(device_count={'GPU': 0}) 23 | tf_config.gpu_options.allow_growth = True 24 | with tf.Session(config=tf_config) as sess: 25 | supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config) 26 | 27 | supervisor.train(sess=sess) 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--config_filename', default=None, type=str, 33 | help='Configuration filename for restoring the model.') 34 | parser.add_argument('--use_cpu_only', default=False, type=bool, help='Set to true to only use cpu.') 35 | args = parser.parse_args() 36 | main(args) 37 | -------------------------------------------------------------------------------- /dcrnn_train_pytorch.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import yaml 7 | 8 | from lib.utils import load_graph_data 9 | from model.pytorch.dcrnn_supervisor import DCRNNSupervisor 10 | 11 | 12 | def main(args): 13 | with open(args.config_filename) as f: 14 | supervisor_config = yaml.load(f) 15 | 16 | graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') 17 | sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename) 18 | 19 | supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config) 20 | 21 | supervisor.train() 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--config_filename', default=None, type=str, 27 | help='Configuration filename for restoring the model.') 28 | parser.add_argument('--use_cpu_only', default=False, type=bool, help='Set to true to only use cpu.') 29 | args = parser.parse_args() 30 | main(args) 31 | -------------------------------------------------------------------------------- /figures/model_architecture.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/figures/model_architecture.jpg -------------------------------------------------------------------------------- /figures/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/figures/result1.png -------------------------------------------------------------------------------- /figures/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/figures/result2.png -------------------------------------------------------------------------------- /figures/result3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/figures/result3.png -------------------------------------------------------------------------------- /figures/result4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chnsh/DCRNN_PyTorch/d92490b808ba5c5be2f23d427d96e9a56b066d7f/figures/result4.png -------------------------------------------------------------------------------- /lib/AMSGrad.py: -------------------------------------------------------------------------------- 1 | """AMSGrad for TensorFlow. 
class AMSGrad(optimizer.Optimizer):
    """AMSGrad optimizer implemented on top of ``optimizer.Optimizer``.

    Three slots are kept per variable ("m", "v" and "vhat") plus two
    non-trainable accumulators (``beta1_power`` and ``beta2_power``).
    Each update subtracts ``lr * m_t / (sqrt(vhat_t) + epsilon)`` from the
    variable, where ``vhat_t`` is the element-wise maximum of the new second
    moment ``v_t`` and the previously stored ``vhat``.
    """

    def __init__(self, learning_rate=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8, use_locking=False, name="AMSGrad"):
        """Store the hyper-parameters; tensors and accumulators are created later.

        :param learning_rate: stored as ``self._lr``.
        :param beta1: decay factor used for the "m" slot.
        :param beta2: decay factor used for the "v" slot.
        :param epsilon: added to ``sqrt(vhat_t)`` in the update denominator.
        :param use_locking: forwarded to the base class and to most assign ops.
        :param name: forwarded to the base class.
        """
        super(AMSGrad, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

        # Tensor versions of the hyper-parameters, filled in by _prepare().
        self._lr_t = None
        self._beta1_t = None
        self._beta2_t = None
        self._epsilon_t = None

        # Running powers of beta1 / beta2, created by _create_slots().
        self._beta1_power = None
        self._beta2_power = None

    def _create_slots(self, var_list):
        """Create the power accumulators (if needed) and the per-variable slots.

        :param var_list: variables to optimize.
        """
        # The variable with the smallest name decides where the accumulators live.
        first_var = min(var_list, key=lambda x: x.name)

        create_new = self._beta1_power is None
        # NOTE(review): `context.in_graph_mode` may not exist in every
        # TensorFlow release -- confirm against the pinned version.
        if not create_new and context.in_graph_mode():
            # Re-create the accumulators when the optimizer is reused in another graph.
            create_new = (self._beta1_power.graph is not first_var.graph)

        if create_new:
            with ops.colocate_with(first_var):
                self._beta1_power = variable_scope.variable(self._beta1, name="beta1_power", trainable=False)
                self._beta2_power = variable_scope.variable(self._beta2, name="beta2_power", trainable=False)
        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "v", self._name)
            self._zeros_slot(v, "vhat", self._name)

    def _prepare(self):
        """Convert the stored hyper-parameters to tensors."""
        self._lr_t = ops.convert_to_tensor(self._lr)
        self._beta1_t = ops.convert_to_tensor(self._beta1)
        self._beta2_t = ops.convert_to_tensor(self._beta2)
        self._epsilon_t = ops.convert_to_tensor(self._epsilon)

    def _apply_dense(self, grad, var):
        """Build the update ops for a dense gradient.

        :param grad: gradient tensor for ``var``.
        :param var: variable to update.
        :return: a grouped op covering the variable update and the three slot updates.
        """
        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

        # Bias-corrected step size.
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking)

        # amsgrad: vhat keeps the running element-wise maximum of v.
        # Unlike the other assigns, this one does not pass use_locking.
        vhat = self.get_slot(var, "vhat")
        vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
        v_sqrt = math_ops.sqrt(vhat_t)

        var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])

    def _resource_apply_dense(self, grad, var):
        """Build the update ops for a dense gradient of a resource variable.

        Mirrors _apply_dense but works on ``.handle`` attributes and casts the
        hyper-parameters to the gradient dtype.

        :param grad: gradient tensor.
        :param var: resource variable; immediately replaced by ``var.handle``.
        :return: a grouped op covering the variable update and the three slot updates.
        """
        # NOTE(review): after this rebinding, get_slot() below receives the
        # handle rather than the variable object, and the handles are used
        # directly in arithmetic -- confirm this path works as intended.
        var = var.handle
        beta1_power = math_ops.cast(self._beta1_power, grad.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, grad.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, grad.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, grad.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, grad.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, grad.dtype.base_dtype)

        # Bias-corrected step size.
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m_t = beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m").handle
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, beta1_t * m + m_scaled_g_values, use_locking=self._use_locking)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v").handle
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, beta2_t * v + v_scaled_g_values, use_locking=self._use_locking)

        # amsgrad
        vhat = self.get_slot(var, "vhat").handle
        vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
        v_sqrt = math_ops.sqrt(vhat_t)

        var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])

    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        """Build the update ops for a sparse gradient.

        :param grad: gradient values for the rows selected by ``indices``.
        :param var: variable to update.
        :param indices: indices the gradient values belong to.
        :param scatter_add: callable ``(slot, indices, values)`` that adds the
            values into the slot and returns the resulting tensor.
        :return: a grouped op covering the variable update and the three slot updates.
        """
        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

        # Bias-corrected step size.
        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

        # m_t = beta1 * m + (1 - beta1) * g_t
        # The whole slot is decayed first; the gradient contribution is then
        # scattered into the selected indices once the decay has run.
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)

        # amsgrad
        vhat = self.get_slot(var, "vhat")
        vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat))
        v_sqrt = math_ops.sqrt(vhat_t)
        var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t, vhat_t])

    def _apply_sparse(self, grad, var):
        """Apply a sparse gradient using ``state_ops.scatter_add``.

        :param grad: object exposing ``values`` and ``indices``.
        :param var: variable to update.
        """
        return self._apply_sparse_shared(
            grad.values, var, grad.indices,
            lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
                x, i, v, use_locking=self._use_locking))

    def _resource_scatter_add(self, x, i, v):
        """Scatter-add ``v`` into resource variable ``x`` at ``i`` and return its value."""
        # Reading the value under the control dependency makes the returned
        # tensor reflect the scatter-add.
        with ops.control_dependencies(
                [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
            return x.value()

    def _resource_apply_sparse(self, grad, var, indices):
        """Apply a sparse gradient to a resource variable via _resource_scatter_add."""
        return self._apply_sparse_shared(
            grad, var, indices, self._resource_scatter_add)

    def _finish(self, update_ops, name_scope):
        """Advance the power accumulators after all variable updates.

        :param update_ops: ops produced by the apply methods.
        :param name_scope: name given to the returned grouped op.
        :return: a grouped op containing ``update_ops`` and both accumulator updates.
        """
        # Update the power accumulators.
        with ops.control_dependencies(update_ops):
            with ops.colocate_with(self._beta1_power):
                update_beta1 = self._beta1_power.assign(
                    self._beta1_power * self._beta1_t,
                    use_locking=self._use_locking)
                update_beta2 = self._beta2_power.assign(
                    self._beta2_power * self._beta2_t,
                    use_locking=self._use_locking)
        return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
                                      name=name_scope)
def masked_rmse_tf(preds, labels, null_val=np.nan):
    """
    Root mean squared error with masking (TensorFlow).
    :param preds: predicted values.
    :param labels: ground-truth values.
    :param null_val: label value treated as missing; forwarded to masked_mse_tf.
    :return: square root of masked_mse_tf(preds, labels, null_val).
    """
    return tf.sqrt(masked_mse_tf(preds=preds, labels=labels, null_val=null_val))


def _masked_weights_np(labels, null_val):
    """
    Per-element weights shared by the NumPy metrics.
    Valid labels get weight 1 / (fraction of valid labels), missing labels get 0,
    so a plain mean over the weighted errors is a mean over the valid entries.
    Call it inside an np.errstate block: with no valid label the division is 0 / 0.
    :param labels: ground-truth array.
    :param null_val: label value treated as missing (NaN means "labels that are NaN").
    :return: float32 weight array with the shape of labels.
    """
    if np.isnan(null_val):
        mask = ~np.isnan(labels)
    else:
        mask = np.not_equal(labels, null_val)
    mask = mask.astype('float32')
    mask /= np.mean(mask)
    return mask


def masked_rmse_np(preds, labels, null_val=np.nan):
    """
    Root mean squared error over the entries whose label is not null_val.
    :param preds: predicted array.
    :param labels: ground-truth array.
    :param null_val: label value treated as missing.
    :return: scalar RMSE; 0 when every label is missing.
    """
    return np.sqrt(masked_mse_np(preds=preds, labels=labels, null_val=null_val))


def masked_mse_np(preds, labels, null_val=np.nan):
    """
    Mean squared error over the entries whose label is not null_val.
    :param preds: predicted array.
    :param labels: ground-truth array.
    :param null_val: label value treated as missing.
    :return: scalar MSE; 0 when every label is missing.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = _masked_weights_np(labels, null_val)
        squared_error = np.square(np.subtract(preds, labels)).astype('float32')
        # NaNs (missing labels, or 0 / 0 weights) contribute nothing to the mean.
        squared_error = np.nan_to_num(squared_error * weights)
        return np.mean(squared_error)


def masked_mae_np(preds, labels, null_val=np.nan):
    """
    Mean absolute error over the entries whose label is not null_val.
    :param preds: predicted array.
    :param labels: ground-truth array.
    :param null_val: label value treated as missing.
    :return: scalar MAE; 0 when every label is missing.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = _masked_weights_np(labels, null_val)
        mae = np.abs(np.subtract(preds, labels)).astype('float32')
        mae = np.nan_to_num(mae * weights)
        return np.mean(mae)


def masked_mape_np(preds, labels, null_val=np.nan):
    """
    Mean absolute percentage error over the entries whose label is not null_val.
    The result is a fraction; it is not multiplied by 100.
    :param preds: predicted array.
    :param labels: ground-truth array, also used as the denominator.
    :param null_val: label value treated as missing.
    :return: scalar MAPE; 0 when every label is missing.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = _masked_weights_np(labels, null_val)
        mape = np.abs(np.divide(np.subtract(preds, labels).astype('float32'), labels))
        # A zero label gives inf here; when it is masked, 0 * inf = NaN becomes 0.
        mape = np.nan_to_num(weights * mape)
        return np.mean(mape)


# Builds loss function.
def masked_mse_loss(scaler, null_val):
    """
    Build a masked-MSE loss closure.
    :param scaler: object with inverse_transform, applied to both arguments when truthy.
    :param null_val: label value treated as missing.
    :return: function (preds, labels) -> masked_mse_tf result.
    """
    def loss(preds, labels):
        if scaler:
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        return masked_mse_tf(preds=preds, labels=labels, null_val=null_val)

    return loss


def masked_rmse_loss(scaler, null_val):
    """
    Build a masked-RMSE loss closure.
    :param scaler: object with inverse_transform, applied to both arguments when truthy.
    :param null_val: label value treated as missing.
    :return: function (preds, labels) -> masked_rmse_tf result.
    """
    def loss(preds, labels):
        if scaler:
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        return masked_rmse_tf(preds=preds, labels=labels, null_val=null_val)

    return loss


def masked_mae_loss(scaler, null_val):
    """
    Build a masked-MAE loss closure.
    :param scaler: object with inverse_transform, applied to both arguments when truthy.
    :param null_val: label value treated as missing.
    :return: function (preds, labels) -> masked_mae_tf result.
    """
    def loss(preds, labels):
        if scaler:
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        mae = masked_mae_tf(preds=preds, labels=labels, null_val=null_val)
        return mae

    return loss


def calculate_metrics(df_pred, df_test, null_val):
    """
    Calculate the MAE, MAPE, RMSE
    :param df_pred: pandas object holding the predictions.
    :param df_test: pandas object holding the ground truth.
    :param null_val: label value treated as missing.
    :return: tuple (mae, mape, rmse).
    """
    # `.values` replaces `DataFrame.as_matrix()`, which no longer exists in
    # current pandas; the conversion is also done once instead of six times.
    preds = df_pred.values
    labels = df_test.values
    mape = masked_mape_np(preds=preds, labels=labels, null_val=null_val)
    mae = masked_mae_np(preds=preds, labels=labels, null_val=null_val)
    rmse = masked_rmse_np(preds=preds, labels=labels, null_val=null_val)
    return mae, mape, rmse
class MyTestCase(unittest.TestCase):
    """Checks for the NumPy masked MAPE and RMSE metrics."""

    @staticmethod
    def _f32(rows):
        # Every fixture in this class is a small float32 matrix.
        return np.array(rows, dtype=np.float32)

    def test_masked_mape_np(self):
        # Only the last of six entries is off, by 1 / 4.
        predicted = self._f32([[1, 2, 2], [3, 4, 5]])
        expected = self._f32([[1, 2, 2], [3, 4, 4]])
        result = metrics.masked_mape_np(preds=predicted, labels=expected)
        self.assertAlmostEqual(1 / 24.0, result, delta=1e-5)

    def test_masked_mape_np2(self):
        # Masking the label value 4 hides the only wrong prediction.
        predicted = self._f32([[1, 2, 2], [3, 4, 5]])
        expected = self._f32([[1, 2, 2], [3, 4, 4]])
        result = metrics.masked_mape_np(preds=predicted, labels=expected, null_val=4)
        self.assertEqual(0., result)

    def test_masked_mape_np_all_zero(self):
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[0, 0], [0, 0]])
        result = metrics.masked_mape_np(preds=predicted, labels=expected, null_val=0)
        self.assertEqual(0., result)

    def test_masked_mape_np_all_nan(self):
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[np.nan, np.nan], [np.nan, np.nan]])
        result = metrics.masked_mape_np(preds=predicted, labels=expected)
        self.assertEqual(0., result)

    def test_masked_mape_np_nan(self):
        # A single valid label (3) predicted as 4.
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[np.nan, np.nan], [np.nan, 3]])
        result = metrics.masked_mape_np(preds=predicted, labels=expected)
        self.assertAlmostEqual(1 / 3., result, delta=1e-5)

    def test_masked_rmse_np_vanilla(self):
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[1, 4], [3, 4]])
        result = metrics.masked_rmse_np(preds=predicted, labels=expected, null_val=0)
        self.assertEqual(1., result)

    def test_masked_rmse_np_nan(self):
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[1, np.nan], [3, 4]])
        result = metrics.masked_rmse_np(preds=predicted, labels=expected)
        self.assertEqual(0., result)

    def test_masked_rmse_np_all_zero(self):
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[0, 0], [0, 0]])
        result = metrics.masked_rmse_np(preds=predicted, labels=expected, null_val=0)
        self.assertEqual(0., result)

    def test_masked_rmse_np_missing(self):
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[1, 0], [3, 4]])
        result = metrics.masked_rmse_np(preds=predicted, labels=expected, null_val=0)
        self.assertEqual(0., result)

    def test_masked_rmse_np2(self):
        # Three valid labels, one of them off by 1.
        predicted = self._f32([[1, 2], [3, 4]])
        expected = self._f32([[1, 0], [3, 3]])
        result = metrics.masked_rmse_np(preds=predicted, labels=expected, null_val=0)
        self.assertAlmostEqual(np.sqrt(1 / 3.), result, delta=1e-5)
class DataLoader(object):
    """Serves (x, y) samples in fixed-size batches."""

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        Store the samples, optionally padded and shuffled.
        :param xs: inputs, indexed along the first axis.
        :param ys: targets, aligned with xs along the first axis.
        :param batch_size: number of samples per batch.
        :param pad_with_last_sample: repeat the last sample until the number of
            samples is divisible by batch_size.
        :param shuffle: apply one random permutation to xs and ys at construction.
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            # Number of copies of the last sample needed to fill the final batch.
            missing = -len(xs) % batch_size
            xs = np.concatenate([xs, np.repeat(xs[-1:], missing, axis=0)], axis=0)
            ys = np.concatenate([ys, np.repeat(ys[-1:], missing, axis=0)], axis=0)
        self.size = len(xs)
        # Only full batches are counted; a trailing partial batch is never served.
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            order = np.random.permutation(self.size)
            xs = xs[order]
            ys = ys[order]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        """
        Restart batching and return a generator of (x_batch, y_batch) tuples.
        The position is kept in self.current_ind, so generators obtained from
        the same loader share their progress.
        """
        self.current_ind = 0

        def _batches():
            while self.current_ind < self.num_batch:
                first = self.batch_size * self.current_ind
                last = min(self.size, first + self.batch_size)
                window = slice(first, last)
                yield (self.xs[window, ...], self.ys[window, ...])
                self.current_ind += 1

        return _batches()


class StandardScaler:
    """
    Standardizes data with a fixed mean and standard deviation.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Return data shifted by the mean and divided by the standard deviation."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo transform(): multiply by the standard deviation, then add the mean."""
        rescaled = data * self.std
        return rescaled + self.mean
def calculate_normalized_laplacian(adj):
    """
    Symmetrically normalized graph Laplacian.
    # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2
    # D = diag(A 1)
    The product is evaluated as (A D^-1/2)^T D^-1/2, which matches the formula
    above when A is symmetric.
    :param adj: adjacency matrix (dense or scipy sparse).
    :return: sparse matrix I - D^-1/2 A D^-1/2.
    """
    adj = sp.coo_matrix(adj)
    d = np.array(adj.sum(1))
    # Zero-degree nodes give inf here; that is expected and reset to 0 below.
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(d, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
    return normalized_laplacian


def calculate_random_walk_matrix(adj_mx):
    """
    Row-normalized adjacency matrix D^-1 A with D = diag(A 1).
    Rows that sum to zero stay all-zero.
    :param adj_mx: adjacency matrix (dense or scipy sparse), any numeric dtype.
    :return: scipy COO matrix.
    """
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1))
    # A float exponent keeps integer adjacency matrices working: NumPy refuses
    # to raise integers to a negative integer power.
    with np.errstate(divide='ignore'):
        d_inv = np.power(d, -1.0).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx


def calculate_reverse_random_walk_matrix(adj_mx):
    """
    Random-walk matrix of the transposed adjacency matrix.
    :param adj_mx: adjacency matrix accepted by np.transpose.
    :return: scipy COO matrix.
    """
    return calculate_random_walk_matrix(np.transpose(adj_mx))


def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """
    Rescaled normalized Laplacian 2 L / lambda_max - I.
    :param adj_mx: dense adjacency matrix.
    :param lambda_max: scaling eigenvalue; when None it is computed from L as
        the eigenvalue of largest magnitude.
    :param undirected: symmetrize adj_mx with an element-wise maximum of the
        matrix and its transpose before building L.
    :return: float32 scipy sparse matrix.
    """
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    laplacian = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        lambda_max, _ = linalg.eigsh(laplacian, 1, which='LM')
        lambda_max = lambda_max[0]
    laplacian = sp.csr_matrix(laplacian)
    num_nodes, _ = laplacian.shape
    identity = sp.identity(num_nodes, format='csr', dtype=laplacian.dtype)
    laplacian = (2 / lambda_max * laplacian) - identity
    return laplacian.astype(np.float32)
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
    """
    Return the named logger, writing to a file in log_dir and to stdout.
    The log directory is created when missing, and calling this function again
    for the same logger and the same targets does not attach duplicate handlers
    (which would repeat every message).
    :param log_dir: directory that receives the log file.
    :param name: logger name passed to logging.getLogger.
    :param log_filename: file name inside log_dir.
    :param level: level set on the logger.
    :return: the configured logging.Logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Create the log directory if necessary; an empty log_dir means the
    # current directory, which needs no creation.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    # FileHandler stores the absolute path, so compare against the same form.
    log_path = os.path.abspath(os.path.join(log_dir, log_filename))

    # Add file handler and stdout handler
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers)
    if not has_file_handler:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # Add console handler. FileHandler subclasses StreamHandler, hence the exclusion.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        and handler.stream is sys.stdout
        for handler in logger.handlers)
    if not has_console_handler:
        console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)

    logger.info('Log directory: %s', log_dir)
    return logger
def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs):
    """
    Load the train/val/test splits, standardize feature 0 and build loaders.
    :param dataset_dir: directory containing train.npz, val.npz and test.npz,
        each with arrays stored under the keys 'x' and 'y'.
    :param batch_size: batch size of the (shuffled) training loader.
    :param test_batch_size: batch size of the validation and test loaders;
        defaults to batch_size when None.
    :param kwargs: accepted for call-site compatibility and not used.
    :return: dict with 'x_*'/'y_*' arrays, '*_loader' DataLoaders and 'scaler'.
    """
    if test_batch_size is None:
        # DataLoader does integer arithmetic on the batch size, so the None
        # default cannot be passed through.
        test_batch_size = batch_size
    categories = ['train', 'val', 'test']
    data = {}
    for category in categories:
        # Close each archive once its arrays have been read into memory.
        with np.load(os.path.join(dataset_dir, category + '.npz')) as cat_data:
            data['x_' + category] = cat_data['x']
            data['y_' + category] = cat_data['y']
    # Statistics come from feature 0 of the training inputs only.
    scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std())
    # Data format
    for category in categories:
        data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
        data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0])
    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
    data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False)
    data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)
    data['scaler'] = scaler

    return data


def load_graph_data(pkl_filename):
    """
    Load the sensor graph stored as a pickled 3-tuple.
    :param pkl_filename: path of the pickle file.
    :return: tuple (sensor_ids, sensor_id_to_ind, adj_mx).
    """
    sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
    return sensor_ids, sensor_id_to_ind, adj_mx


def load_pickle(pickle_file):
    """
    Unpickle a file, retrying with latin1 decoding on a UnicodeDecodeError.
    Any other failure of the first attempt is printed and re-raised.
    :param pickle_file: path of the pickle file.
    :return: the unpickled object.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # from a trusted source.
    try:
        with open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f)
    except UnicodeDecodeError:
        with open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f, encoding='latin1')
    except Exception as e:
        print('Unable to load data ', pickle_file, ':', e)
        raise
    return pickle_data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LayerParams:
    """Creates parameters on first use and registers them on the owning module."""

    def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
        """
        :param rnn_network: module on which new parameters are registered.
        :param layer_type: prefix used in the registered parameter names.
        """
        self._rnn_network = rnn_network
        self._params_dict = {}
        self._biases_dict = {}
        self._type = layer_type

    def get_weights(self, shape):
        """
        Return the weight of the given shape, creating it on first request.
        :param shape: tuple used both as tensor shape and as cache key.
        :return: Xavier-normal initialised torch.nn.Parameter.
        """
        if shape not in self._params_dict:
            nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
            torch.nn.init.xavier_normal_(nn_param)
            self._params_dict[shape] = nn_param
            self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
                                                 nn_param)
        return self._params_dict[shape]

    def get_biases(self, length, bias_start=0.0):
        """
        Return the bias of the given length, creating it on first request.
        :param length: number of elements, also the cache key.
        :param bias_start: constant used to initialise a newly created bias.
        :return: torch.nn.Parameter of shape (length,).
        """
        if length not in self._biases_dict:
            biases = torch.nn.Parameter(torch.empty(length, device=device))
            torch.nn.init.constant_(biases, bias_start)
            self._biases_dict[length] = biases
            self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
                                                 biases)

        return self._biases_dict[length]


class DCGRUCell(torch.nn.Module):
    """GRU cell whose gates and candidate state use diffusion graph convolution."""

    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
                 filter_type="laplacian", use_gc_for_ru=True):
        """

        :param num_units: hidden units per node.
        :param adj_mx: adjacency matrix used to build the support matrices.
        :param max_diffusion_step: highest diffusion order used by the graph convolution.
        :param num_nodes: number of graph nodes.
        :param nonlinearity: 'tanh' selects torch.tanh; any other value selects torch.relu.
        :param filter_type: "laplacian", "random_walk", "dual_random_walk".
        :param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
        """

        super().__init__()
        self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
        # support other nonlinearities up here?
        self._num_nodes = num_nodes
        self._num_units = num_units
        self._max_diffusion_step = max_diffusion_step
        # Kept in a plain list, so the sparse supports are not registered as
        # buffers of the module.
        self._supports = []
        self._use_gc_for_ru = use_gc_for_ru
        supports = []
        if filter_type == "laplacian":
            supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None))
        elif filter_type == "random_walk":
            supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
        elif filter_type == "dual_random_walk":
            supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
            supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T)
        else:
            supports.append(utils.calculate_scaled_laplacian(adj_mx))
        for support in supports:
            self._supports.append(self._build_sparse_matrix(support))

        self._fc_params = LayerParams(self, 'fc')
        self._gconv_params = LayerParams(self, 'gconv')

    @staticmethod
    def _build_sparse_matrix(L):
        """
        Convert a scipy sparse matrix into a torch sparse COO tensor on `device`.
        :param L: scipy sparse matrix.
        :return: torch sparse tensor with the same entries as L.
        """
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        # np.lexsort uses its last key as the primary one: entries are ordered
        # by column, then by row. The values must follow the same permutation,
        # otherwise they end up attached to the wrong coordinates.
        order = np.lexsort((indices[:, 0], indices[:, 1]))
        L = torch.sparse_coo_tensor(indices[order].T, L.data[order], L.shape, device=device)
        return L

    def forward(self, inputs, hx):
        """Gated recurrent unit (GRU) with Graph Convolution.
        :param inputs: (B, num_nodes * input_dim)
        :param hx: (B, num_nodes * rnn_units)

        :return
        - Output: A `2-D` tensor with shape `(B, num_nodes * rnn_units)`.
        """
        output_size = 2 * self._num_units
        if self._use_gc_for_ru:
            fn = self._gconv
        else:
            fn = self._fc
        # Reset and update gates are produced together and split afterwards.
        value = torch.sigmoid(fn(inputs, hx, output_size, bias_start=1.0))
        value = torch.reshape(value, (-1, self._num_nodes, output_size))
        r, u = torch.split(tensor=value, split_size_or_sections=self._num_units, dim=-1)
        r = torch.reshape(r, (-1, self._num_nodes * self._num_units))
        u = torch.reshape(u, (-1, self._num_nodes * self._num_units))

        # Candidate state from the inputs and the reset-gated previous state.
        c = self._gconv(inputs, r * hx, self._num_units)
        if self._activation is not None:
            c = self._activation(c)

        new_state = u * hx + (1.0 - u) * c
        return new_state

    @staticmethod
    def _concat(x, x_):
        """Append x_ to x as a new slice along dimension 0."""
        x_ = x_.unsqueeze(0)
        return torch.cat([x, x_], dim=0)

    def _fc(self, inputs, state, output_size, bias_start=0.0):
        """
        Per-node fully connected layer over the concatenated inputs and state.
        :param inputs: (B, num_nodes * input_dim)
        :param state: (B, num_nodes * state_dim)
        :param output_size: output features per node.
        :param bias_start: initial value of the bias when it is first created.
        :return: tensor of shape (B * num_nodes, output_size).
        """
        batch_size = inputs.shape[0]
        inputs = torch.reshape(inputs, (batch_size * self._num_nodes, -1))
        state = torch.reshape(state, (batch_size * self._num_nodes, -1))
        inputs_and_state = torch.cat([inputs, state], dim=-1)
        input_size = inputs_and_state.shape[-1]
        weights = self._fc_params.get_weights((input_size, output_size))
        # NOTE(review): a sigmoid is applied here before the bias, and forward()
        # applies another one to the result -- confirm this is intended.
        value = torch.sigmoid(torch.matmul(inputs_and_state, weights))
        biases = self._fc_params.get_biases(output_size, bias_start)
        # Out-of-place add: the backward pass of sigmoid needs its output
        # unmodified, so `value += biases` makes autograd fail.
        value = value + biases
        return value

    def _gconv(self, inputs, state, output_size, bias_start=0.0):
        """
        Diffusion graph convolution over the concatenated inputs and state.
        :param inputs: (B, num_nodes * input_dim)
        :param state: (B, num_nodes * state_dim)
        :param output_size: output features per node.
        :param bias_start: initial value of the bias when it is first created.
        :return: tensor of shape (B, num_nodes * output_size).
        """
        # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
        batch_size = inputs.shape[0]
        inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
        state = torch.reshape(state, (batch_size, self._num_nodes, -1))
        inputs_and_state = torch.cat([inputs, state], dim=2)
        input_size = inputs_and_state.size(2)

        x = inputs_and_state
        x0 = x.permute(1, 2, 0)  # (num_nodes, total_arg_size, batch_size)
        x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
        x = torch.unsqueeze(x0, 0)

        if self._max_diffusion_step != 0:
            for support in self._supports:
                x1 = torch.sparse.mm(support, x0)
                x = self._concat(x, x1)

                # Recurrence x_k = 2 * S x_{k-1} - x_{k-2}.
                # NOTE(review): x0 is rebound inside this loop, so a second
                # support starts from the terms left by the first one rather
                # than from the original input -- confirm this is intended.
                for k in range(2, self._max_diffusion_step + 1):
                    x2 = 2 * torch.sparse.mm(support, x1) - x0
                    x = self._concat(x, x2)
                    x1, x0 = x2, x1

        num_matrices = len(self._supports) * self._max_diffusion_step + 1  # Adds for x itself.
        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])

        weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
        x = torch.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)

        biases = self._gconv_params.get_biases(output_size, bias_start)
        x += biases
        # Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim)
        return torch.reshape(x, [batch_size, self._num_nodes * output_size])
class EncoderModel(nn.Module, Seq2SeqAttrs):
    """Stack of DCGRU cells that consumes one encoder time step per call."""

    def __init__(self, adj_mx, **model_kwargs):
        """
        :param adj_mx: adjacency matrix handed to every ``DCGRUCell``.
        :param model_kwargs: model configuration; ``input_dim`` defaults to 1,
                             ``seq_len`` is required (``int(None)`` fails otherwise).
        """
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
        self.input_dim = int(model_kwargs.get('input_dim', 1))
        self.seq_len = int(model_kwargs.get('seq_len'))  # for the encoder
        self.dcgru_layers = nn.ModuleList(
            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])

    def forward(self, inputs, hidden_state=None):
        """
        Encoder forward pass for a single time step.

        :param inputs: shape (batch_size, self.num_nodes * self.input_dim)
        :param hidden_state: (num_layers, batch_size, self.hidden_state_size)
               optional, zeros if not provided
        :return: output: # shape (batch_size, self.hidden_state_size)
                 hidden_state # shape (num_layers, batch_size, self.hidden_state_size)
                 (lower indices mean lower layers)
        """
        batch_size, _ = inputs.size()
        if hidden_state is None:
            # Allocate next to the data rather than on the module-level default
            # device, so the model also works when it is run on a device other
            # than the one picked at import time.
            hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
                                       device=inputs.device)
        hidden_states = []
        output = inputs
        for layer_num, dcgru_layer in enumerate(self.dcgru_layers):
            # Each layer's new state is both recorded and fed to the next layer.
            next_hidden_state = dcgru_layer(output, hidden_state[layer_num])
            hidden_states.append(next_hidden_state)
            output = next_hidden_state

        return output, torch.stack(hidden_states)  # runs in O(num_layers) so not too slow
class DCRNNModel(nn.Module, Seq2SeqAttrs):
    """Sequence-to-sequence DCRNN: a DCGRU encoder followed by an autoregressive decoder."""

    def __init__(self, adj_mx, logger, **model_kwargs):
        """
        :param adj_mx: adjacency matrix forwarded to the encoder and decoder.
        :param logger: logger used for progress and parameter-count messages.
        :param model_kwargs: model configuration shared by encoder and decoder.
        """
        super().__init__()
        Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
        self.encoder_model = EncoderModel(adj_mx, **model_kwargs)
        self.decoder_model = DecoderModel(adj_mx, **model_kwargs)
        self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
        self.use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False))
        self._logger = logger

    def _compute_sampling_threshold(self, batches_seen):
        """Inverse-sigmoid decay: probability of feeding the ground truth to the decoder."""
        return self.cl_decay_steps / (
                self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

    def encoder(self, inputs):
        """
        encoder forward pass on t time steps
        :param inputs: shape (seq_len, batch_size, num_sensor * input_dim)
        :return: encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size)
        """
        encoder_hidden_state = None
        for t in range(self.encoder_model.seq_len):
            _, encoder_hidden_state = self.encoder_model(inputs[t], encoder_hidden_state)

        return encoder_hidden_state

    def decoder(self, encoder_hidden_state, labels=None, batches_seen=None):
        """
        Decoder forward pass
        :param encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size)
        :param labels: (self.horizon, batch_size, self.num_nodes * self.output_dim) [optional, not exist for inference]
        :param batches_seen: global step [optional, not exist for inference]
        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
        """
        batch_size = encoder_hidden_state.size(1)
        # Allocate the GO symbol on the same device as the encoder state
        # instead of the module-level default device.
        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim),
                                device=encoder_hidden_state.device)
        decoder_hidden_state = encoder_hidden_state
        decoder_input = go_symbol

        # Scheduled sampling needs both the ground truth and the global step.
        # Both are documented as optional, so without them the decoder falls
        # back to feeding its own predictions instead of raising a TypeError.
        use_scheduled_sampling = (self.training and self.use_curriculum_learning
                                  and labels is not None and batches_seen is not None)

        outputs = []

        for t in range(self.decoder_model.horizon):
            decoder_output, decoder_hidden_state = self.decoder_model(decoder_input,
                                                                      decoder_hidden_state)
            decoder_input = decoder_output
            outputs.append(decoder_output)
            if use_scheduled_sampling:
                c = np.random.uniform(0, 1)
                if c < self._compute_sampling_threshold(batches_seen):
                    decoder_input = labels[t]
        outputs = torch.stack(outputs)
        return outputs

    def forward(self, inputs, labels=None, batches_seen=None):
        """
        seq2seq forward pass
        :param inputs: shape (seq_len, batch_size, num_sensor * input_dim)
        :param labels: shape (horizon, batch_size, num_sensor * output)
        :param batches_seen: batches seen till now
        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
        """
        encoder_hidden_state = self.encoder(inputs)
        self._logger.debug("Encoder complete, starting decoder")
        outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
        self._logger.debug("Decoder complete")
        if batches_seen == 0:
            # Logged after the first forward pass of training.
            self._logger.info(
                "Total trainable parameters {}".format(count_parameters(self))
            )
        return outputs
self._logger.info("Model created") 47 | 48 | self._epoch_num = self._train_kwargs.get('epoch', 0) 49 | if self._epoch_num > 0: 50 | self.load_model() 51 | 52 | @staticmethod 53 | def _get_log_dir(kwargs): 54 | log_dir = kwargs['train'].get('log_dir') 55 | if log_dir is None: 56 | batch_size = kwargs['data'].get('batch_size') 57 | learning_rate = kwargs['train'].get('base_lr') 58 | max_diffusion_step = kwargs['model'].get('max_diffusion_step') 59 | num_rnn_layers = kwargs['model'].get('num_rnn_layers') 60 | rnn_units = kwargs['model'].get('rnn_units') 61 | structure = '-'.join( 62 | ['%d' % rnn_units for _ in range(num_rnn_layers)]) 63 | horizon = kwargs['model'].get('horizon') 64 | filter_type = kwargs['model'].get('filter_type') 65 | filter_type_abbr = 'L' 66 | if filter_type == 'random_walk': 67 | filter_type_abbr = 'R' 68 | elif filter_type == 'dual_random_walk': 69 | filter_type_abbr = 'DR' 70 | run_id = 'dcrnn_%s_%d_h_%d_%s_lr_%g_bs_%d_%s/' % ( 71 | filter_type_abbr, max_diffusion_step, horizon, 72 | structure, learning_rate, batch_size, 73 | time.strftime('%m%d%H%M%S')) 74 | base_dir = kwargs.get('base_dir') 75 | log_dir = os.path.join(base_dir, run_id) 76 | if not os.path.exists(log_dir): 77 | os.makedirs(log_dir) 78 | return log_dir 79 | 80 | def save_model(self, epoch): 81 | if not os.path.exists('models/'): 82 | os.makedirs('models/') 83 | 84 | config = dict(self._kwargs) 85 | config['model_state_dict'] = self.dcrnn_model.state_dict() 86 | config['epoch'] = epoch 87 | torch.save(config, 'models/epo%d.tar' % epoch) 88 | self._logger.info("Saved model at {}".format(epoch)) 89 | return 'models/epo%d.tar' % epoch 90 | 91 | def load_model(self): 92 | self._setup_graph() 93 | assert os.path.exists('models/epo%d.tar' % self._epoch_num), 'Weights at epoch %d not found' % self._epoch_num 94 | checkpoint = torch.load('models/epo%d.tar' % self._epoch_num, map_location='cpu') 95 | self.dcrnn_model.load_state_dict(checkpoint['model_state_dict']) 96 | 
def _train(self, base_lr,
           steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
           test_every_n_epochs=10, epsilon=1e-8, **kwargs):
    """Run the training loop with validation, periodic testing and early stopping.

    :param base_lr: initial learning rate for Adam.
    :param steps: epoch milestones passed to ``MultiStepLR``.
    :param patience: number of consecutive epochs without validation
                     improvement after which training stops.
    :param epochs: exclusive upper bound on the epoch index.
    :param lr_decay_ratio: multiplicative decay (``gamma``) applied at each milestone.
    :param log_every: log a train/val summary every this many epochs.
    :param save_model: if truthy, checkpoint whenever the validation loss improves.
    :param test_every_n_epochs: evaluate on the test split every this many epochs.
    :param epsilon: ``eps`` for Adam.
    :param kwargs: remaining training options; not used here.
    """
    min_val_loss = float('inf')
    wait = 0
    optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr, eps=epsilon)

    # NOTE(review): the schedule always starts from its first epoch, so a run
    # resumed at self._epoch_num > 0 restarts the milestone count.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                        gamma=lr_decay_ratio)

    self._logger.info('Start training ...')

    # this will fail if model is loaded with a changed batch_size
    num_batches = self._data['train_loader'].num_batch
    self._logger.info("num_batches:{}".format(num_batches))

    batches_seen = num_batches * self._epoch_num

    for epoch_num in range(self._epoch_num, epochs):

        self.dcrnn_model = self.dcrnn_model.train()

        train_iterator = self._data['train_loader'].get_iterator()
        losses = []

        start_time = time.time()

        for x, y in train_iterator:
            optimizer.zero_grad()

            x, y = self._prepare_data(x, y)

            output = self.dcrnn_model(x, y, batches_seen)

            if batches_seen == 0:
                # this is a workaround to accommodate dynamically registered parameters in DCGRUCell
                optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr, eps=epsilon)
                # The scheduler must follow the optimizer that is actually
                # stepped; left bound to the discarded one, the learning rate
                # in use would never decay.
                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                                    gamma=lr_decay_ratio)

            loss = self._compute_loss(y, output)

            self._logger.debug(loss.item())

            losses.append(loss.item())

            batches_seen += 1
            loss.backward()

            # gradient clipping - this does it in place
            torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)

            optimizer.step()
        self._logger.info("epoch complete")
        lr_scheduler.step()
        self._logger.info("evaluating now!")

        val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen)

        end_time = time.time()

        train_loss = np.mean(losses)
        # Read the rate from the optimizer itself: it is the value in use, and
        # calling scheduler.get_lr() outside step() is unreliable on newer PyTorch.
        current_lr = optimizer.param_groups[0]['lr']

        self._writer.add_scalar('training loss',
                                train_loss,
                                batches_seen)

        if (epoch_num % log_every) == log_every - 1:
            message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                      '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                       train_loss, val_loss, current_lr,
                                       (end_time - start_time))
            self._logger.info(message)

        if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
            test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen)
            message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                      '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                       train_loss, test_loss, current_lr,
                                       (end_time - start_time))
            self._logger.info(message)

        if val_loss < min_val_loss:
            wait = 0
            if save_model:
                model_file_name = self.save_model(epoch_num)
                self._logger.info(
                    'Val loss decrease from {:.4f} to {:.4f}, '
                    'saving to {}'.format(min_val_loss, val_loss, model_file_name))
            min_val_loss = val_loss
        else:
            # Also reached for a NaN validation loss, so a diverged run still
            # counts toward early stopping.
            wait += 1
            if wait == patience:
                self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                break
def masked_mae_loss(y_pred, y_true):
    """Mean absolute error that ignores positions where ``y_true`` is exactly 0.

    Valid positions are re-weighted by the inverse of the valid fraction, so
    the final mean is an average over valid entries only. If no entry is
    valid the weights become NaN; those NaNs are zeroed and the loss is 0.
    """
    valid = (y_true != 0).float()
    weights = valid / valid.mean()
    weighted_error = torch.abs(y_pred - y_true) * weights
    # NaN appears only through the 0 / 0 weights above (or NaN inputs).
    weighted_error = weighted_error.masked_fill(torch.isnan(weighted_error), 0)
    return weighted_error.mean()
def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, num_proj=None,
             activation=tf.nn.tanh, reuse=None, filter_type="laplacian", use_gc_for_ru=True):
    """Store the cell configuration and build the sparse graph support matrices.

    :param num_units: hidden units per node (state size is num_nodes * num_units).
    :param adj_mx: adjacency matrix the supports are computed from.
    :param max_diffusion_step: highest diffusion order used by the graph convolution.
    :param num_nodes: number of graph nodes.
    :param num_proj: optional per-node projection size of the output.
    :param activation: activation applied to the candidate state, or None.
    :param reuse: forwarded to the RNNCell base class as ``_reuse``.
    :param filter_type: "laplacian", "random_walk", "dual_random_walk";
                        any other value falls back to the scaled Laplacian.
    :param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
    """
    super(DCGRUCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._num_nodes = num_nodes
    self._num_proj = num_proj
    self._activation = activation
    self._max_diffusion_step = max_diffusion_step
    self._use_gc_for_ru = use_gc_for_ru

    if filter_type == "dual_random_walk":
        # Forward and backward transition matrices.
        raw_supports = [utils.calculate_random_walk_matrix(adj_mx).T,
                        utils.calculate_random_walk_matrix(adj_mx.T).T]
    elif filter_type == "random_walk":
        raw_supports = [utils.calculate_random_walk_matrix(adj_mx).T]
    elif filter_type == "laplacian":
        raw_supports = [utils.calculate_scaled_laplacian(adj_mx, lambda_max=None)]
    else:
        raw_supports = [utils.calculate_scaled_laplacian(adj_mx)]

    self._supports = [self._build_sparse_matrix(matrix) for matrix in raw_supports]
output_size(self): 72 | output_size = self._num_nodes * self._num_units 73 | if self._num_proj is not None: 74 | output_size = self._num_nodes * self._num_proj 75 | return output_size 76 | 77 | def __call__(self, inputs, state, scope=None): 78 | """Gated recurrent unit (GRU) with Graph Convolution. 79 | :param inputs: (B, num_nodes * input_dim) 80 | 81 | :return 82 | - Output: A `2-D` tensor with shape `[batch_size x self.output_size]`. 83 | - New state: Either a single `2-D` tensor, or a tuple of tensors matching 84 | the arity and shapes of `state` 85 | """ 86 | with tf.variable_scope(scope or "dcgru_cell"): 87 | with tf.variable_scope("gates"): # Reset gate and update gate. 88 | output_size = 2 * self._num_units 89 | # We start with bias of 1.0 to not reset and not update. 90 | if self._use_gc_for_ru: 91 | fn = self._gconv 92 | else: 93 | fn = self._fc 94 | value = tf.nn.sigmoid(fn(inputs, state, output_size, bias_start=1.0)) 95 | value = tf.reshape(value, (-1, self._num_nodes, output_size)) 96 | r, u = tf.split(value=value, num_or_size_splits=2, axis=-1) 97 | r = tf.reshape(r, (-1, self._num_nodes * self._num_units)) 98 | u = tf.reshape(u, (-1, self._num_nodes * self._num_units)) 99 | with tf.variable_scope("candidate"): 100 | c = self._gconv(inputs, r * state, self._num_units) 101 | if self._activation is not None: 102 | c = self._activation(c) 103 | output = new_state = u * state + (1 - u) * c 104 | if self._num_proj is not None: 105 | with tf.variable_scope("projection"): 106 | w = tf.get_variable('w', shape=(self._num_units, self._num_proj)) 107 | batch_size = inputs.get_shape()[0].value 108 | output = tf.reshape(new_state, shape=(-1, self._num_units)) 109 | output = tf.reshape(tf.matmul(output, w), shape=(batch_size, self.output_size)) 110 | return output, new_state 111 | 112 | @staticmethod 113 | def _concat(x, x_): 114 | x_ = tf.expand_dims(x_, 0) 115 | return tf.concat([x, x_], axis=0) 116 | 117 | def _fc(self, inputs, state, output_size, 
bias_start=0.0): 118 | dtype = inputs.dtype 119 | batch_size = inputs.get_shape()[0].value 120 | inputs = tf.reshape(inputs, (batch_size * self._num_nodes, -1)) 121 | state = tf.reshape(state, (batch_size * self._num_nodes, -1)) 122 | inputs_and_state = tf.concat([inputs, state], axis=-1) 123 | input_size = inputs_and_state.get_shape()[-1].value 124 | weights = tf.get_variable( 125 | 'weights', [input_size, output_size], dtype=dtype, 126 | initializer=tf.contrib.layers.xavier_initializer()) 127 | value = tf.nn.sigmoid(tf.matmul(inputs_and_state, weights)) 128 | biases = tf.get_variable("biases", [output_size], dtype=dtype, 129 | initializer=tf.constant_initializer(bias_start, dtype=dtype)) 130 | value = tf.nn.bias_add(value, biases) 131 | return value 132 | 133 | def _gconv(self, inputs, state, output_size, bias_start=0.0): 134 | """Graph convolution between input and the graph matrix. 135 | 136 | :param args: a 2D Tensor or a list of 2D, batch x n, Tensors. 137 | :param output_size: 138 | :param bias: 139 | :param bias_start: 140 | :param scope: 141 | :return: 142 | """ 143 | # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim) 144 | batch_size = inputs.get_shape()[0].value 145 | inputs = tf.reshape(inputs, (batch_size, self._num_nodes, -1)) 146 | state = tf.reshape(state, (batch_size, self._num_nodes, -1)) 147 | inputs_and_state = tf.concat([inputs, state], axis=2) 148 | input_size = inputs_and_state.get_shape()[2].value 149 | dtype = inputs.dtype 150 | 151 | x = inputs_and_state 152 | x0 = tf.transpose(x, perm=[1, 2, 0]) # (num_nodes, total_arg_size, batch_size) 153 | x0 = tf.reshape(x0, shape=[self._num_nodes, input_size * batch_size]) 154 | x = tf.expand_dims(x0, axis=0) 155 | 156 | scope = tf.get_variable_scope() 157 | with tf.variable_scope(scope): 158 | if self._max_diffusion_step == 0: 159 | pass 160 | else: 161 | for support in self._supports: 162 | x1 = tf.sparse_tensor_dense_matmul(support, x0) 163 | x = self._concat(x, x1) 164 | 
165 | for k in range(2, self._max_diffusion_step + 1): 166 | x2 = 2 * tf.sparse_tensor_dense_matmul(support, x1) - x0 167 | x = self._concat(x, x2) 168 | x1, x0 = x2, x1 169 | 170 | num_matrices = len(self._supports) * self._max_diffusion_step + 1 # Adds for x itself. 171 | x = tf.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size]) 172 | x = tf.transpose(x, perm=[3, 1, 2, 0]) # (batch_size, num_nodes, input_size, order) 173 | x = tf.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices]) 174 | 175 | weights = tf.get_variable( 176 | 'weights', [input_size * num_matrices, output_size], dtype=dtype, 177 | initializer=tf.contrib.layers.xavier_initializer()) 178 | x = tf.matmul(x, weights) # (batch_size * self._num_nodes, output_size) 179 | 180 | biases = tf.get_variable("biases", [output_size], dtype=dtype, 181 | initializer=tf.constant_initializer(bias_start, dtype=dtype)) 182 | x = tf.nn.bias_add(x, biases) 183 | # Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim) 184 | return tf.reshape(x, [batch_size, self._num_nodes * output_size]) 185 | -------------------------------------------------------------------------------- /model/tf/dcrnn_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from tensorflow.contrib import legacy_seq2seq 8 | 9 | from model.tf.dcrnn_cell import DCGRUCell 10 | 11 | 12 | class DCRNNModel(object): 13 | def __init__(self, is_training, batch_size, scaler, adj_mx, **model_kwargs): 14 | # Scaler for data normalization. 
15 | self._scaler = scaler 16 | 17 | # Train and loss 18 | self._loss = None 19 | self._mae = None 20 | self._train_op = None 21 | 22 | max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2)) 23 | cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000)) 24 | filter_type = model_kwargs.get('filter_type', 'laplacian') 25 | horizon = int(model_kwargs.get('horizon', 1)) 26 | max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0)) 27 | num_nodes = int(model_kwargs.get('num_nodes', 1)) 28 | num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1)) 29 | rnn_units = int(model_kwargs.get('rnn_units')) 30 | seq_len = int(model_kwargs.get('seq_len')) 31 | use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False)) 32 | input_dim = int(model_kwargs.get('input_dim', 1)) 33 | output_dim = int(model_kwargs.get('output_dim', 1)) 34 | 35 | # Input (batch_size, timesteps, num_sensor, input_dim) 36 | self._inputs = tf.placeholder(tf.float32, shape=(batch_size, seq_len, num_nodes, input_dim), name='inputs') 37 | # Labels: (batch_size, timesteps, num_sensor, input_dim), same format with input except the temporal dimension. 
38 | self._labels = tf.placeholder(tf.float32, shape=(batch_size, horizon, num_nodes, input_dim), name='labels') 39 | 40 | # GO_SYMBOL = tf.zeros(shape=(batch_size, num_nodes * input_dim)) 41 | GO_SYMBOL = tf.zeros(shape=(batch_size, num_nodes * output_dim)) 42 | 43 | cell = DCGRUCell(rnn_units, adj_mx, max_diffusion_step=max_diffusion_step, num_nodes=num_nodes, 44 | filter_type=filter_type) 45 | cell_with_projection = DCGRUCell(rnn_units, adj_mx, max_diffusion_step=max_diffusion_step, num_nodes=num_nodes, 46 | num_proj=output_dim, filter_type=filter_type) 47 | encoding_cells = [cell] * num_rnn_layers 48 | decoding_cells = [cell] * (num_rnn_layers - 1) + [cell_with_projection] 49 | encoding_cells = tf.contrib.rnn.MultiRNNCell(encoding_cells, state_is_tuple=True) 50 | decoding_cells = tf.contrib.rnn.MultiRNNCell(decoding_cells, state_is_tuple=True) 51 | 52 | global_step = tf.train.get_or_create_global_step() 53 | # Outputs: (batch_size, timesteps, num_nodes, output_dim) 54 | with tf.variable_scope('DCRNN_SEQ'): 55 | inputs = tf.unstack(tf.reshape(self._inputs, (batch_size, seq_len, num_nodes * input_dim)), axis=1) 56 | labels = tf.unstack( 57 | tf.reshape(self._labels[..., :output_dim], (batch_size, horizon, num_nodes * output_dim)), axis=1) 58 | labels.insert(0, GO_SYMBOL) 59 | 60 | def _loop_function(prev, i): 61 | if is_training: 62 | # Return either the model's prediction or the previous ground truth in training. 63 | if use_curriculum_learning: 64 | c = tf.random_uniform((), minval=0, maxval=1.) 65 | threshold = self._compute_sampling_threshold(global_step, cl_decay_steps) 66 | result = tf.cond(tf.less(c, threshold), lambda: labels[i], lambda: prev) 67 | else: 68 | result = labels[i] 69 | else: 70 | # Return the prediction of the model in testing. 
class DCRNNSupervisor(object):
    """
    Builds the TensorFlow DCRNN train/test graphs and drives training,
    evaluation and checkpointing for them.

    The constructor expects the parsed experiment configuration as keyword
    arguments with (at least) the sub-dictionaries ``data``, ``model`` and
    ``train``, plus ``base_dir`` when ``train.log_dir`` is not given.
    """

    def __init__(self, adj_mx, **kwargs):
        """
        Load the dataset, build the train/test models and the training op.
        :param adj_mx: adjacency matrix forwarded unchanged to DCRNNModel.
        :param kwargs: experiment configuration ('data', 'model', 'train', ...).
        """

        self._kwargs = kwargs
        self._data_kwargs = kwargs.get('data')
        self._model_kwargs = kwargs.get('model')
        self._train_kwargs = kwargs.get('train')

        # logging.
        self._log_dir = self._get_log_dir(kwargs)
        log_level = self._kwargs.get('log_level', 'INFO')
        self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)
        self._writer = tf.summary.FileWriter(self._log_dir)
        self._logger.info(kwargs)

        # Data preparation
        self._data = utils.load_dataset(**self._data_kwargs)
        for k, v in self._data.items():
            if hasattr(v, 'shape'):
                self._logger.info((k, v.shape))

        # Build models. The test model reuses the variables created by the train model.
        scaler = self._data['scaler']
        with tf.name_scope('Train'):
            with tf.variable_scope('DCRNN', reuse=False):
                self._train_model = DCRNNModel(is_training=True, scaler=scaler,
                                               batch_size=self._data_kwargs['batch_size'],
                                               adj_mx=adj_mx, **self._model_kwargs)

        with tf.name_scope('Test'):
            with tf.variable_scope('DCRNN', reuse=True):
                self._test_model = DCRNNModel(is_training=False, scaler=scaler,
                                              batch_size=self._data_kwargs['test_batch_size'],
                                              adj_mx=adj_mx, **self._model_kwargs)

        # Learning rate: a non-trainable variable updated through a placeholder (see set_lr).
        self._lr = tf.get_variable('learning_rate', shape=(), initializer=tf.constant_initializer(0.01),
                                   trainable=False)
        self._new_lr = tf.placeholder(tf.float32, shape=(), name='new_learning_rate')
        self._lr_update = tf.assign(self._lr, self._new_lr, name='lr_update')

        # Configure optimizer. Adam is the default; 'sgd' and 'amsgrad' override it.
        optimizer_name = self._train_kwargs.get('optimizer', 'adam').lower()
        epsilon = float(self._train_kwargs.get('epsilon', 1e-3))
        optimizer = tf.train.AdamOptimizer(self._lr, epsilon=epsilon)
        if optimizer_name == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(self._lr, )
        elif optimizer_name == 'amsgrad':
            optimizer = AMSGrad(self._lr, epsilon=epsilon)

        # Calculate loss
        output_dim = self._model_kwargs.get('output_dim')
        preds = self._train_model.outputs
        labels = self._train_model.labels[..., :output_dim]

        null_val = 0.
        self._loss_fn = masked_mae_loss(scaler, null_val)
        self._train_loss = self._loss_fn(preds=preds, labels=labels)

        # Gradients are clipped by global norm before being applied.
        tvars = tf.trainable_variables()
        grads = tf.gradients(self._train_loss, tvars)
        max_grad_norm = kwargs['train'].get('max_grad_norm', 1.)
        grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
        global_step = tf.train.get_or_create_global_step()
        self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step, name='train_op')

        max_to_keep = self._train_kwargs.get('max_to_keep', 100)
        self._epoch = 0
        self._saver = tf.train.Saver(tf.global_variables(), max_to_keep=max_to_keep)

        # Log model statistics.
        total_trainable_parameter = utils.get_total_trainable_parameter_size()
        self._logger.info('Total number of trainable parameters: {:d}'.format(total_trainable_parameter))
        for var in tf.global_variables():
            self._logger.debug('{}, {}'.format(var.name, var.get_shape()))

    @staticmethod
    def _get_log_dir(kwargs):
        """
        Return the log directory, creating it on disk if it does not exist.
        Uses ``kwargs['train']['log_dir']`` when present; otherwise derives a
        run id from the model/data hyper-parameters and the current time and
        places it under ``kwargs['base_dir']``.
        :param kwargs: experiment configuration.
        :return: path of the log directory.
        """
        log_dir = kwargs['train'].get('log_dir')
        if log_dir is None:
            batch_size = kwargs['data'].get('batch_size')
            learning_rate = kwargs['train'].get('base_lr')
            max_diffusion_step = kwargs['model'].get('max_diffusion_step')
            num_rnn_layers = kwargs['model'].get('num_rnn_layers')
            rnn_units = kwargs['model'].get('rnn_units')
            structure = '-'.join(
                ['%d' % rnn_units for _ in range(num_rnn_layers)])
            horizon = kwargs['model'].get('horizon')
            filter_type = kwargs['model'].get('filter_type')
            filter_type_abbr = 'L'
            if filter_type == 'random_walk':
                filter_type_abbr = 'R'
            elif filter_type == 'dual_random_walk':
                filter_type_abbr = 'DR'
            run_id = 'dcrnn_%s_%d_h_%d_%s_lr_%g_bs_%d_%s/' % (
                filter_type_abbr, max_diffusion_step, horizon,
                structure, learning_rate, batch_size,
                time.strftime('%m%d%H%M%S'))
            base_dir = kwargs.get('base_dir')
            log_dir = os.path.join(base_dir, run_id)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        return log_dir

    def run_epoch_generator(self, sess, model, data_generator, return_output=False, training=False, writer=None):
        """
        Run ``model`` once over every (x, y) batch yielded by ``data_generator``.
        :param sess: TensorFlow session.
        :param model: model exposing ``inputs``, ``labels``, ``outputs`` and ``merged``.
        :param data_generator: iterable of (x, y) batches.
        :param return_output: also collect the per-batch model outputs.
        :param training: run the train op (and the merged summary, if any).
        :param writer: optional summary writer for the merged summary.
        :return: dict with mean 'loss' and 'mae' (both computed from the same
                 masked MAE tensor) and, if requested, the list of 'outputs'.
        """
        losses = []
        maes = []
        outputs = []
        output_dim = self._model_kwargs.get('output_dim')
        preds = model.outputs
        labels = model.labels[..., :output_dim]
        loss = self._loss_fn(preds=preds, labels=labels)
        fetches = {
            'loss': loss,
            'mae': loss,
            'global_step': tf.train.get_or_create_global_step()
        }
        if training:
            fetches.update({
                'train_op': self._train_op
            })
            merged = model.merged
            if merged is not None:
                fetches.update({'merged': merged})

        if return_output:
            fetches.update({
                'outputs': model.outputs
            })

        for x, y in data_generator:
            feed_dict = {
                model.inputs: x,
                model.labels: y,
            }

            vals = sess.run(fetches, feed_dict=feed_dict)

            losses.append(vals['loss'])
            maes.append(vals['mae'])
            if writer is not None and 'merged' in vals:
                writer.add_summary(vals['merged'], global_step=vals['global_step'])
            if return_output:
                outputs.append(vals['outputs'])

        results = {
            'loss': np.mean(losses),
            'mae': np.mean(maes)
        }
        if return_output:
            results['outputs'] = outputs
        return results

    def get_lr(self, sess):
        """Return the current learning rate as a Python scalar."""
        # np.asscalar was removed in NumPy 1.23; .item() is the supported equivalent.
        return np.asarray(sess.run(self._lr)).item()

    def set_lr(self, sess, lr):
        """Assign ``lr`` to the learning-rate variable."""
        sess.run(self._lr_update, feed_dict={
            self._new_lr: lr
        })

    def train(self, sess, **kwargs):
        """Train with the configured 'train' options; they override ``kwargs``."""
        kwargs.update(self._train_kwargs)
        return self._train(sess, **kwargs)

    def _train(self, sess, base_lr, epoch, steps, patience=50, epochs=100,
               min_learning_rate=2e-6, lr_decay_ratio=0.1, save_model=1,
               test_every_n_epochs=10, **train_kwargs):
        """
        Training loop with step-wise learning-rate decay, validation,
        periodic test evaluation, checkpointing and early stopping.
        :param sess: TensorFlow session.
        :param base_lr: initial learning rate.
        :param epoch: epoch of the restored checkpoint (used when 'model_filename' is set).
        :param steps: epochs at which the learning rate is multiplied by lr_decay_ratio.
        :param patience: number of non-improving epochs tolerated before stopping.
        :param epochs: last epoch to run (inclusive).
        :param min_learning_rate: lower bound of the decayed learning rate.
        :param lr_decay_ratio: multiplicative decay applied at each step.
        :param save_model: save a checkpoint on validation improvement when > 0.
        :param test_every_n_epochs: period of the test-set evaluation.
        :param train_kwargs: remaining options ('max_to_keep', 'model_filename', ...).
        :return: the smallest validation MAE observed, or ``inf`` if no epoch
                 completed validation.
        """
        history = []
        min_val_loss = float('inf')
        wait = 0

        max_to_keep = train_kwargs.get('max_to_keep', 100)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=max_to_keep)
        model_filename = train_kwargs.get('model_filename')
        if model_filename is not None:
            saver.restore(sess, model_filename)
            self._epoch = epoch + 1
        else:
            sess.run(tf.global_variables_initializer())
        self._logger.info('Start training ...')

        while self._epoch <= epochs:
            # Learning rate schedule.
            new_lr = max(min_learning_rate, base_lr * (lr_decay_ratio ** np.sum(self._epoch >= np.array(steps))))
            self.set_lr(sess=sess, lr=new_lr)

            start_time = time.time()
            train_results = self.run_epoch_generator(sess, self._train_model,
                                                     self._data['train_loader'].get_iterator(),
                                                     training=True,
                                                     writer=self._writer)
            train_loss, train_mae = train_results['loss'], train_results['mae']
            if train_loss > 1e5:
                self._logger.warning('Gradient explosion detected. Ending...')
                break

            global_step = sess.run(tf.train.get_or_create_global_step())
            # Compute validation error.
            val_results = self.run_epoch_generator(sess, self._test_model,
                                                   self._data['val_loader'].get_iterator(),
                                                   training=False)
            # .item() replaces np.asscalar, which no longer exists in NumPy >= 1.23.
            val_loss = np.asarray(val_results['loss']).item()
            val_mae = np.asarray(val_results['mae']).item()

            utils.add_simple_summary(self._writer,
                                     ['loss/train_loss', 'metric/train_mae', 'loss/val_loss', 'metric/val_mae'],
                                     [train_loss, train_mae, val_loss, val_mae], global_step=global_step)
            end_time = time.time()
            message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f} lr:{:.6f} {:.1f}s'.format(
                self._epoch, epochs, global_step, train_mae, val_mae, new_lr, (end_time - start_time))
            self._logger.info(message)
            if self._epoch % test_every_n_epochs == test_every_n_epochs - 1:
                self.evaluate(sess)
            if val_loss <= min_val_loss:
                wait = 0
                if save_model > 0:
                    model_filename = self.save(sess, val_loss)
                self._logger.info(
                    'Val loss decrease from %.4f to %.4f, saving to %s' % (min_val_loss, val_loss, model_filename))
                min_val_loss = val_loss
            else:
                wait += 1
                if wait > patience:
                    self._logger.warning('Early stopping at epoch: %d' % self._epoch)
                    break

            history.append(val_mae)
            # Increases epoch.
            self._epoch += 1

            sys.stdout.flush()
        # history is empty when the loop never finished a validation pass
        # (e.g. gradient explosion in the first epoch); np.min([]) would raise.
        if not history:
            return float('inf')
        return np.min(history)

    def evaluate(self, sess, **kwargs):
        """
        Run the test model over the test set, log per-horizon MAE/MAPE/RMSE
        and return the de-normalized predictions and ground truth.
        :param sess: TensorFlow session.
        :return: dict with 'predictions' and 'groundtruth', each a list with one
                 array per horizon step.
        """
        global_step = sess.run(tf.train.get_or_create_global_step())
        test_results = self.run_epoch_generator(sess, self._test_model,
                                                self._data['test_loader'].get_iterator(),
                                                return_output=True,
                                                training=False)

        # y_preds: a list of (batch_size, horizon, num_nodes, output_dim)
        test_loss, y_preds = test_results['loss'], test_results['outputs']
        utils.add_simple_summary(self._writer, ['loss/test_loss'], [test_loss], global_step=global_step)

        y_preds = np.concatenate(y_preds, axis=0)
        scaler = self._data['scaler']
        predictions = []
        y_truths = []
        for horizon_i in range(self._data['y_test'].shape[1]):
            y_truth = scaler.inverse_transform(self._data['y_test'][:, horizon_i, :, 0])
            y_truths.append(y_truth)

            # Predictions are truncated to the number of ground-truth rows
            # (presumably the loader pads the last batch -- TODO confirm).
            y_pred = scaler.inverse_transform(y_preds[:y_truth.shape[0], horizon_i, :, 0])
            predictions.append(y_pred)

            mae = metrics.masked_mae_np(y_pred, y_truth, null_val=0)
            mape = metrics.masked_mape_np(y_pred, y_truth, null_val=0)
            rmse = metrics.masked_rmse_np(y_pred, y_truth, null_val=0)
            self._logger.info(
                "Horizon {:02d}, MAE: {:.2f}, MAPE: {:.4f}, RMSE: {:.2f}".format(
                    horizon_i + 1, mae, mape, rmse
                )
            )
            utils.add_simple_summary(self._writer,
                                     ['%s_%d' % (item, horizon_i + 1) for item in
                                      ['metric/rmse', 'metric/mape', 'metric/mae']],
                                     [rmse, mape, mae],
                                     global_step=global_step)
        outputs = {
            'predictions': predictions,
            'groundtruth': y_truths
        }
        return outputs

    def load(self, sess, model_filename):
        """
        Restore from saved model.
        :param sess: TensorFlow session.
        :param model_filename: checkpoint path passed to the saver.
        :return:
        """
        self._saver.restore(sess, model_filename)

    def save(self, sess, val_loss):
        """
        Save a checkpoint and a matching ``config_<epoch>.yaml`` in the log dir.
        :param sess: TensorFlow session.
        :param val_loss: validation loss, embedded in the checkpoint prefix.
        :return: the checkpoint path returned by the saver.
        """
        # NOTE(review): dict() is a shallow copy, so the writes to config['train']
        # below also update the supervisor's own 'train' options.
        config = dict(self._kwargs)
        # .item() replaces np.asscalar, which no longer exists in NumPy >= 1.23.
        global_step = np.asarray(sess.run(tf.train.get_or_create_global_step())).item()
        prefix = os.path.join(self._log_dir, 'models-{:.4f}'.format(val_loss))
        config['train']['epoch'] = self._epoch
        config['train']['global_step'] = global_step
        config['train']['log_dir'] = self._log_dir
        config['train']['model_filename'] = self._saver.save(sess, prefix, global_step=global_step,
                                                             write_meta_graph=False)
        config_filename = 'config_{}.yaml'.format(self._epoch)
        with open(os.path.join(self._log_dir, config_filename), 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
        return config['train']['model_filename']
def run_dcrnn(args):
    """
    Evaluate the PyTorch DCRNN supervisor on the test split and save the result.

    Reads the YAML config at ``args.config_filename``, loads the adjacency
    matrix from the graph pickle named under ``data.graph_pkl_filename``,
    evaluates the supervisor on 'test', writes the returned outputs to
    ``args.output_filename`` as a compressed ``.npz`` and prints the mean score.
    :param args: parsed command-line arguments (config_filename, output_filename).
    """
    with open(args.config_filename) as f:
        # safe_load: yaml.load() without an explicit Loader can construct
        # arbitrary Python objects and is rejected outright by PyYAML >= 6.
        supervisor_config = yaml.safe_load(f)

    graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
    # Only the adjacency matrix is needed here.
    _, _, adj_mx = load_graph_data(graph_pkl_filename)

    supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config)
    mean_score, outputs = supervisor.evaluate('test')
    np.savez_compressed(args.output_filename, **outputs)
    print("MAE : {}".format(mean_score))
    print('Predictions saved as {}.'.format(args.output_filename))
def historical_average_predict(df, period=12 * 24 * 7, test_ratio=0.2, null_val=0.):
    """
    Predict each test row as the mean of the training rows that fall on the
    same position within the period, ignoring readings equal to ``null_val``.
    :param df: readings, one row per time step.
    :param period: default 1 week.
    :param test_ratio: fraction of rows (taken from the end) used as test set.
    :param null_val: default 0.
    :return: (y_predict, y_test), both covering the test rows.
    """
    total_rows, _ = df.shape
    test_rows = int(round(total_rows * test_ratio))
    train_rows = total_rows - test_rows
    y_test = df[-test_rows:]
    y_predict = pd.DataFrame.copy(y_test)

    # First period of the test range: average the same-phase training rows.
    first_period_end = min(total_rows, train_rows + period)
    for row in range(train_rows, first_period_end):
        same_phase_rows = list(range(row % period, train_rows, period))
        history = df.iloc[same_phase_rows, :]
        valid_history = history[history != null_val]
        y_predict.iloc[row - train_rows, :] = valid_history.mean()

    # Every later period repeats the predictions made one period earlier.
    for row in range(train_rows + period, total_rows, period):
        span = min(period, total_rows - row)
        dst = row - train_rows
        src = dst - period
        y_predict.iloc[dst:dst + span, :] = y_predict.iloc[src:src + span, :].values
    return y_predict, y_test
def eval_static(traffic_reading_df):
    """
    Log RMSE / MAPE / MAE of the static baseline for horizons 1, 3, 6 and 12.
    :param traffic_reading_df: readings passed to ``static_predict``.
    """
    logger.info('Static')
    horizons = [1, 3, 6, 12]
    logger.info('\t'.join(['Model', 'Horizon', 'RMSE', 'MAPE', 'MAE']))
    for horizon in horizons:
        y_predict, y_test = static_predict(traffic_reading_df, n_forward=horizon, test_ratio=0.2)
        # DataFrame.as_matrix() was removed in pandas 1.0; .values is the
        # equivalent accessor. Extract the arrays once per horizon.
        preds, labels = y_predict.values, y_test.values
        rmse = masked_rmse_np(preds=preds, labels=labels, null_val=0)
        mape = masked_mape_np(preds=preds, labels=labels, null_val=0)
        mae = masked_mae_np(preds=preds, labels=labels, null_val=0)
        line = 'Static\t%d\t%.2f\t%.2f\t%.2f' % (horizon, rmse, mape * 100, mae)
        logger.info(line)
def get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1):
    """
    Build a thresholded Gaussian-kernel adjacency matrix from pairwise distances.

    :param distance_df: data frame with three columns: [from, to, distance].
    :param sensor_ids: list of sensor ids.
    :param normalized_k: entries that become lower than normalized_k after normalization are set to zero for sparsity.
    :return: (sensor_ids, sensor_id_to_ind, adj_mx)
    """
    num_sensors = len(sensor_ids)
    # Pairs without a listed distance stay at infinity and end up with weight 0.
    dist_mx = np.full((num_sensors, num_sensors), np.inf, dtype=np.float32)

    # Builds sensor id to index map.
    sensor_id_to_ind = {sensor_id: i for i, sensor_id in enumerate(sensor_ids)}

    # Fills cells in the matrix with distances; rows naming unknown sensors are skipped.
    for row in distance_df.values:
        src, dst = row[0], row[1]
        if src in sensor_id_to_ind and dst in sensor_id_to_ind:
            dist_mx[sensor_id_to_ind[src], sensor_id_to_ind[dst]] = row[2]

    # Uses the standard deviation of the known distances as the kernel width.
    known_distances = dist_mx[~np.isinf(dist_mx)]
    std = known_distances.std()
    adj_mx = np.exp(-np.square(dist_mx / std))

    # Zeroes the entries that are lower than the threshold, for sparsity.
    too_small = adj_mx < normalized_k
    adj_mx[too_small] = 0
    return sensor_ids, sensor_id_to_ind, adj_mx
def generate_graph_seq2seq_io_data(
        df, x_offsets, y_offsets, add_time_in_day=True, add_day_in_week=False, scaler=None
):
    """
    Slice a readings frame into sliding-window (input, target) samples.
    :param df: readings indexed by timestamp, one column per node.
    :param x_offsets: array of input offsets relative to the last observation t.
    :param y_offsets: array of target offsets relative to t.
    :param add_time_in_day: append the fraction of the day elapsed as a feature.
    :param add_day_in_week: append a 7-way one-hot day-of-week feature.
    :param scaler: accepted but not used by this function.
    :return:
    # x: (epoch_size, input_length, num_nodes, input_dim)
    # y: (epoch_size, output_length, num_nodes, output_dim)
    """
    num_samples, num_nodes = df.shape
    features = [np.expand_dims(df.values, axis=-1)]

    if add_time_in_day:
        timestamps = df.index.values
        day_fraction = (timestamps - timestamps.astype("datetime64[D]")) / np.timedelta64(1, "D")
        # (1, num_nodes, num_samples) -> (num_samples, num_nodes, 1)
        features.append(np.tile(day_fraction, [1, num_nodes, 1]).transpose((2, 1, 0)))

    if add_day_in_week:
        one_hot_day = np.zeros(shape=(num_samples, num_nodes, 7))
        one_hot_day[np.arange(num_samples), :, df.index.dayofweek] = 1
        features.append(one_hot_day)

    data = np.concatenate(features, axis=-1)

    # t is the index of the last observation.
    first_t = abs(min(x_offsets))
    end_t = abs(num_samples - abs(max(y_offsets)))  # Exclusive
    time_steps = range(first_t, end_t)
    x = np.stack([data[t + x_offsets, ...] for t in time_steps], axis=0)
    y = np.stack([data[t + y_offsets, ...] for t in time_steps], axis=0)
    return x, y
79 | num_samples = x.shape[0] 80 | num_test = round(num_samples * 0.2) 81 | num_train = round(num_samples * 0.7) 82 | num_val = num_samples - num_test - num_train 83 | 84 | # train 85 | x_train, y_train = x[:num_train], y[:num_train] 86 | # val 87 | x_val, y_val = ( 88 | x[num_train: num_train + num_val], 89 | y[num_train: num_train + num_val], 90 | ) 91 | # test 92 | x_test, y_test = x[-num_test:], y[-num_test:] 93 | 94 | for cat in ["train", "val", "test"]: 95 | _x, _y = locals()["x_" + cat], locals()["y_" + cat] 96 | print(cat, "x: ", _x.shape, "y:", _y.shape) 97 | np.savez_compressed( 98 | os.path.join(args.output_dir, "%s.npz" % cat), 99 | x=_x, 100 | y=_y, 101 | x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]), 102 | y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 103 | ) 104 | 105 | 106 | def main(args): 107 | print("Generating training data") 108 | generate_train_val_test(args) 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument( 114 | "--output_dir", type=str, default="data/", help="Output directory." 115 | ) 116 | parser.add_argument( 117 | "--traffic_df_filename", 118 | type=str, 119 | default="data/metr-la.h5", 120 | help="Raw traffic readings.", 121 | ) 122 | args = parser.parse_args() 123 | main(args) 124 | --------------------------------------------------------------------------------