├── Colab_run_model.ipynb ├── Implementation details.pdf ├── README.md ├── config_user.json ├── data_prepare ├── 1.1-data_population_inflow.py ├── 1.2-data_external_variable.py ├── 1.3-data_libcity_format.py └── 1.4-data_prepare_plot.py ├── figures ├── DC.png ├── framework.png ├── topbott__BM.png └── topbott__DC.png ├── libcity ├── __init__.py ├── config │ ├── __init__.py │ ├── config_parser.py │ ├── data │ │ ├── MTHDataset.json │ │ ├── TrafficStateDataset.json │ │ └── TrafficStatePointDataset.json │ ├── evaluator │ │ └── TrafficStateEvaluator.json │ ├── executor │ │ └── TrafficStateExecutor.json │ ├── model │ │ └── traffic_state_pred │ │ │ └── MultiATGCN.json │ └── task_config.json ├── data │ ├── __init__.py │ ├── batch.py │ ├── dataset │ │ ├── __init__.py │ │ ├── abstract_dataset.py │ │ ├── dataset_subclass │ │ │ ├── __init__.py │ │ │ └── mth_dataset.py │ │ ├── traffic_state_datatset.py │ │ └── traffic_state_point_dataset.py │ ├── list_dataset.py │ └── utils.py ├── evaluator │ ├── __init__.py │ ├── abstract_evaluator.py │ ├── eval_funcs.py │ ├── traffic_state_evaluator.py │ └── utils.py ├── executor │ ├── __init__.py │ ├── abstract_executor.py │ ├── hyper_tuning.py │ └── traffic_state_executor.py ├── model │ ├── __init__.py │ ├── abstract_model.py │ ├── abstract_traffic_state_model.py │ ├── abstract_traffic_tradition_model.py │ ├── loss.py │ ├── traffic_flow_prediction │ │ ├── MultiATGCN.py │ │ └── __init__.py │ └── utils.py ├── pipeline │ ├── __init__.py │ └── pipeline.py ├── temp │ ├── 1.4-data_prepare_plot_POI.py │ ├── MultiATGCN-2NE.py │ ├── MultiATGCN-37TP.py │ ├── MultiATGCN-3NE.py │ ├── MultiATGCN-3TU-7POI.py │ ├── MultiATGCN-3TU-FULL.py │ ├── MultiATGCN-3TU.py │ ├── MultiATGCN-3TUSimple.py │ ├── MultiATGCN-7POI.py │ ├── MultiATGCN-APT.py │ ├── MultiATGCN-POI.py │ ├── MultiATGCN-S2S.py │ ├── MultiATGCN-Traffic.py │ ├── MultiATGCN-cc.py │ ├── MultiATGCN-skip.py │ ├── MultiATGCN-weather.py │ ├── MultiATGCN3UT.py │ ├── MultiATGCN_1011.py │ ├── MultiATGCN_1014.py │ ├── MultiATGCN_bestby0929.py │ ├── MultiATGCN_external.py │ ├── STSGCN.py │ ├── STTN.py │ ├── Seq2Seq.py │ ├── [F]MultiATGCN_3TU_Before.py │ ├── __init__.py │ ├── result_convert.py │ ├── result_convert_local_old.py │ └── temp.py └── utils │ ├── __init__.py │ ├── argument_list.py │ ├── dataset.py │ ├── normalization.py │ ├── utils.py │ └── visualize.py ├── other_data ├── CTractFIPS_201901010601_BM_visit_mstd.pkl └── CTractFIPS_201901010601_DC_visit_mstd.pkl ├── raw_data ├── 201901010601_BM_SG_CTractFIPS_Hourly_Single_GP │ ├── 201901010601_BM_SG_CTractFIPS_Hourly_Single_GP.7z │ └── config.json └── 201901010601_DC_SG_CTractFIPS_Hourly_Single_GP │ ├── 201901010601_DC_SG_CTractFIPS_Hourly_Single_GP.7z │ └── config.json ├── requirements.txt ├── result_convert.py ├── result_plot.py ├── run_model.py └── run_model_parameter.py /Colab_run_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "kYfsiR14WCKI" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from google.colab import drive\n", 12 | "drive.mount('/content/drive')\n", 13 | "# %cd /content/drive/Othercomputers/My Computer/PycharmProjects/Bigscity-LibCity/\n", 14 | "%cd /content/drive/MyDrive/MultiSTGraph/" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "id": "IrDMAZLQyj-Q" 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "gpu_info = !nvidia-smi\n", 26 | "gpu_info = 
'\\n'.join(gpu_info)\n", 27 | "if gpu_info.find('failed') >= 0:\n", 28 | " print('Not connected to a GPU')\n", 29 | "else:\n", 30 | " print(gpu_info)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "id": "EeG0Co7Wa1z4" 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "# import torchtext\n", 42 | "# torchtext.__version__\n", 43 | "!pip3 install ray\n", 44 | "!pip3 install -U torchtext\n", 45 | "!pip3 install -U hyperopt\n", 46 | "!pip3 install dgl\n", 47 | "!pip3 install dtaidistance\n", 48 | "!pip3 install --upgrade gensim\n", 49 | "!pip3 install torch==1.10.2+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/torch_stable.html \n", 50 | "# !pip3 install torchtext==0.10.0" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "id": "qrg6iJn-YaRv" 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "import torch\n", 62 | "print(torch.__version__)\n", 63 | "import argparse\n", 64 | "from libcity.pipeline import run_model\n", 65 | "from libcity.utils import str2bool, add_general_args\n", 66 | "\n", 67 | "model_list = ['MultiATGCN']\n", 68 | "if __name__ == '__main__':\n", 69 | " for model_name in model_list:\n", 70 | " parser = argparse.ArgumentParser()\n", 71 | " parser.add_argument('--task', type=str, default='traffic_state_pred', help='the name of task')\n", 72 | " parser.add_argument('--model', type=str, default=model_name, help='the name of model')\n", 73 | " parser.add_argument('--dataset', type=str, default='201901010601_BM_SG_CTractFIPS_Hourly_Single_GP',\n", 74 | " help='the name of dataset')\n", 75 | " parser.add_argument('--config_file', type=str, default='config_user', help='the file name of config file')\n", 76 | " parser.add_argument('--saved_model', type=str2bool, default=True, help='whether save the trained model')\n", 77 | " parser.add_argument('--train', type=str2bool, default=True, help='whether re-train if the model is trained')\n", 78 | " parser.add_argument('--exp_id', type=str, default=None, help='id of experiment')\n", 79 | " parser.add_argument('--seed', type=int, default=0, help='random seed')\n", 80 | " parser.add_argument('--start_dim', type=int, default=0, help='start_dim')\n", 81 | " parser.add_argument('--end_dim', type=int, default=1, help='end_dim')\n", 82 | " add_general_args(parser)\n", 83 | " args, unknown = parser.parse_known_args()\n", 84 | " dict_args = vars(args)\n", 85 | " other_args = {key: val for key, val in dict_args.items() if key not in\n", 86 | " ['task', 'model', 'dataset', 'config_file', 'saved_model', 'train'] and val is not None}\n", 87 | " run_model(task=args.task, model_name=args.model, dataset_name=args.dataset, config_file=args.config_file,\n", 88 | " saved_model=args.saved_model, train=args.train, other_args=other_args)" 89 | ] 90 | } 91 | ], 92 | "metadata": { 93 | "colab": { 94 | "collapsed_sections": [], 95 | "machine_shape": "hm", 96 | "provenance": [] 97 | }, 98 | "gpuClass": "standard", 99 | "kernelspec": { 100 | "display_name": "Python 3", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "name": "python" 105 | }, 106 | "accelerator": "GPU" 107 | }, 108 | "nbformat": 4, 109 | "nbformat_minor": 0 110 | } 111 | -------------------------------------------------------------------------------- /Implementation details.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/Implementation details.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-ATGCN: A Multi-View Graph Neural Network-based Framework for Citywide Crowd Inflow Forecasting 2 | 3 | ![Multi-ATGCN](figures/framework.png "Model Architecture") 4 | 5 | This is an original PyTorch implementation of Multi-ATGCN in the following working paper: 6 | 7 | **Multi-ATGCN: A Multi-View Graph Neural Network-based Framework for Citywide Crowd Inflow Forecasting** 8 | 9 | ## Environment 10 | We use torch == 1.10.2 and Python 3.6.11 for implementation. 11 | 12 | We follow the framework of [LibCity](https://github.com/LibCity/Bigscity-LibCity) to prepare data and run the model. 13 | See more details in `requirements.txt` and the [LibCity document](https://bigscity-libcity-docs.readthedocs.io/en/latest/index.html). 14 | 15 | 21 | 22 | ## Data Preparation 23 | Data files for Washington, D.C. and Baltimore City are available in the `raw_data/` folder. Please extract them to the current folder and 24 | you will get a set of atomic files following the LibCity Unified Data Structures: 25 | 26 | | filename | content | example | 27 | |-------------|------------------------------------------------------------------------------|------------------------------------------------------| 28 | | xxx.geo | Store geographic entity attribute information. | geo_id, type, coordinates | 29 | | xxx.rel | Store the relationship information between entities, i.e., the adjacency matrix. | rel_id, type, origin_id, destination_id, link_weight | 30 | | xxx.dyna | Store hourly crowd flow information. | dyna_id, type, time, entity_id, Visits | 31 | | xxx.ext | Store external time-varying information, such as weather, holidays, etc. | ext_id, time, properties[...] | 32 | | xxx.static | Store external static information, such as socioeconomics, POIs, demographics. | geo_id, properties[...] | 33 | | xxx.gbst | Store mean and std for each geo unit before the group-based z-score. | geo_id, mean, std | 34 | | config.json | Used to supplement the description of the above table information. | | 35 | 36 | The .dyna files are retrieved from [SafeGraph](https://www.safegraph.com/) using the Weekly Places Patterns Dataset. 37 | Run the scripts in `./data_prepare` to prepare the data and convert it to the required format if you have access to the raw Weekly Places Patterns Dataset. 38 | 39 | [//]: # (We don't have the permission to share the data but you can request it via the Safegraph website.) 40 | 41 | The data statistics of the two datasets are as follows: 42 | 43 | 44 | | | Washington, D.C. | Baltimore City (and surrounding counties) | 45 | |---------------------------|-----------------------------------------------|-------------------------------------------| 46 | | Date Range | 01/01/2019 - 05/31/2019 | 01/01/2019 - 05/31/2019 | 47 | | # Nodes | 237 | 403 | 48 | | # Samples | 858,888 | 1,460,472 | 49 | | Sample Rate | 1 hour | 1 hour | 50 | | Input length | 24 hours | 24 hours | 51 | | Output length | 3 hours, 6 hours, 12 hours, 24 hours | 3 hours, 6 hours, 12 hours, 24 hours | 52 | | Mean of crowd flow | 30.169 | 14.41 | 53 | | St.d. of crowd flow | 84.023 | 29.3 |
54 | 55 | ![Data Preparation](figures/DC.png "DC") 56 | 57 | ## Code Structure 58 | For easy comparison among different models, the code and data formats follow the framework proposed by [LibCity](https://github.com/LibCity/Bigscity-LibCity): 59 | * The code for Multi-ATGCN is located at `./libcity/model/traffic_flow_prediction/MultiATGCN.py`. 60 | * The code for dataset preprocessing is located at `./libcity/data/dataset/dataset_subclass/mth_dataset.py`. 61 | * The configuration for the model is located at `./libcity/config/model/traffic_state_pred/MultiATGCN.json`. 62 | * The user-defined configuration with the highest priority is located at `./config_user.json`. 63 | 64 | [//]: # (* The code for other baselines is located at `./libcity/model/`.) 65 | 66 | :exclamation: You can also directly copy the data and our model to the LibCity environment and run them there. 67 | However, we suggest using this repository, since several changes have been made relative to the original LibCity: 68 | * A new data format for static variables is added. 69 | * A new dataset class, `mth_dataset`, is added to support multi-head temporal fusion across all models. 70 | * Group-based normalization is supported in model evaluation. 71 | * Time-varying external variables and time-varying calendar variables can be included separately. 72 | * Configurations of the model, data, and executor are changed accordingly to fit our dataset. 73 | 74 | [//]: # (* Only those with performance greater than vanilla RNN are selected from LibCity as baselines in our study.) 75 | 76 | ## Model Training 77 | The script `run_model.py` is used for training and evaluating the main model: 78 | ```bash 79 | # DC 80 | python run_model.py --task traffic_state_pred --dataset 201901010601_DC_SG_CTractFIPS_Hourly_Single_GP 81 | 82 | # Baltimore 83 | python run_model.py --task traffic_state_pred --dataset 201901010601_BM_SG_CTractFIPS_Hourly_Single_GP 84 | ``` 85 | 86 | The script `run_model_parameter.py` is used for the parameter study and the ablation study. Change the parameters you are interested in and run: 87 | 88 | ```bash 89 | # DC 90 | python run_model_parameter.py --task traffic_state_pred --dataset 201901010601_DC_SG_CTractFIPS_Hourly_Single_GP 91 | 92 | # Baltimore 93 | python run_model_parameter.py --task traffic_state_pred --dataset 201901010601_BM_SG_CTractFIPS_Hourly_Single_GP 94 | ``` 95 | 96 | If you are using Google Colab, we also provide a notebook named `Colab_run_model.ipynb` for execution in the Colab environment. 97 | Clone all files to your drive and execute the code blocks successively. 98 |
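Alternatively, the same entry point used by the notebook can be called directly from a Python session. Below is a minimal sketch based on the call in `Colab_run_model.ipynb` (the dataset name and seed simply mirror the examples above):

```python
from libcity.pipeline import run_model

# Train and evaluate Multi-ATGCN on the DC dataset with config_user.json;
# swap the dataset name for the Baltimore run.
run_model(task='traffic_state_pred', model_name='MultiATGCN',
          dataset_name='201901010601_DC_SG_CTractFIPS_Hourly_Single_GP',
          config_file='config_user', saved_model=True, train=True,
          other_args={'seed': 0})
```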
99 | ### Results 100 | See details in our paper. 101 | #### Top and last three census tracts (Baltimore) 102 | ![Top and last three census tracts (Baltimore)](figures/topbott__BM.png "Top and last three census tracts' forecasting results") 103 | Top and last three census tracts' forecasting results 104 | 105 | #### Top and last three census tracts (DC) 106 | ![Top and last three census tracts (DC)](figures/topbott__DC.png "Top and last three census tracts' forecasting results") 107 | Top and last three census tracts' forecasting results 108 | -------------------------------------------------------------------------------- /config_user.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_epoch": 30, 3 | "train_rate": 0.7, 4 | "eval_rate": 0.15, 5 | "input_window": 24, 6 | "output_window": 24, 7 | "load_external": true, 8 | "load_dynamic": false, 9 | "add_time_in_day": true, 10 | "add_day_in_week": false, 11 | "groupstd": true, 12 | "min_s": 1e-4, 13 | "hour_each_day": 24, 14 | "interval_period": 7, 15 | "interval_trend": 28, 16 | "len_closeness": 2, 17 | "len_period": 1, 18 | "len_trend": 1, 19 | "use_early_stop": true, 20 | "patience": 6 21 | } -------------------------------------------------------------------------------- /data_prepare/1.2-data_external_variable.py: -------------------------------------------------------------------------------- 1 | ######################################################################## 2 | # Prepare external information: weather, holidays, 3 | # and static variables (POI types, socio-economics, demographics, etc.) 4 | ######################################################################## 5 | import pandas as pd 6 | import numpy as np 7 | import os 8 | import datetime 9 | import glob 10 | import geopandas as gpd 11 | 12 | pd.options.mode.chained_assignment = None 13 | results_path = r'D:\\ST_Graph\\Data\\' 14 | 15 | # Get county subdivision 16 | CTS_Info = pd.read_pickle(r'D:\ST_Graph\Results\CTS_Info.pkl') 17 | t_s = datetime.datetime(2019, 1, 1) 18 | t_e = datetime.datetime(2020, 12, 1) 19 | 20 | # Ext: add weather 21 | # Stations in BMC 22 | g_stat = pd.read_pickle(r'E:\Weather\Daily\weather_raw_2019.pkl') 23 | g_stat = g_stat[['STATION', 'LATITUDE', 'LONGITUDE']] 24 | g_stat = g_stat.drop_duplicates(subset=['STATION']).dropna() 25 | g_stat_s = gpd.GeoDataFrame(g_stat, geometry=gpd.points_from_xy(g_stat['LONGITUDE'], g_stat['LATITUDE'])) 26 | ghcnd_station_s = g_stat_s.set_crs('EPSG:4326') 27 | SInBMC = gpd.sjoin(ghcnd_station_s, CTS_Info, how='inner', op='within').reset_index(drop=True) 28 | SInBMC = SInBMC[['STATION', 'LATITUDE', 'LONGITUDE', 'CTSFIPS']] 29 | 30 | # Read weather data 31 | afiles = glob.glob(r'E:\Weather\Hourly\2019\*.csv') + glob.glob(r'E:\Weather\Hourly\2020\*.csv') 32 | nlist = list(SInBMC['STATION'].astype(str)) 33 | nfiles = [i for e in nlist for i in afiles if e in i] 34 | hourly_wea = pd.concat(map(pd.read_csv, nfiles)).reset_index(drop=True) 35 | hourly_wea = hourly_wea.merge(SInBMC[['STATION', 'CTSFIPS']], on='STATION') 36 | 37 | # Only need wnd (m/s), tmp (Celsius), vis (horizontal distance, meters), AA1 (rain, millimeters), AJ1 (snow, centimeters) 38 | # https://www.visualcrossing.com/resources/documentation/weather-data/how-we-process-integrated-surface-database-historical-weather-data/ 39 | # https://www.ncei.noaa.gov/data/global-hourly/doc/isd-format-document.pdf 40 | # https://www.ncei.noaa.gov/data/global-hourly/doc/CSV_HELP.pdf 41 | hourly_wea['vis'] = hourly_wea['VIS'].str.split(',').str[0].astype(float) # m 42 |
hourly_wea['wind'] = hourly_wea['WND'].str.split(',').str[3].astype(float) * 0.1 # m/s 43 | hourly_wea['temp'] = hourly_wea['TMP'].str.split(',').str[0].astype(float) * 0.1 # Celsius 44 | hourly_wea['rain'] = hourly_wea['AA1'].str.split(',').str[1].astype(float) * 0.1 # millimeters 45 | hourly_wea['snow'] = hourly_wea['AJ1'].str.split(',').str[0].astype(float) * 10 # millimeters 46 | hourly_wea = hourly_wea[['STATION', 'CTSFIPS', 'DATE', 'LATITUDE', 'LONGITUDE', 'wind', 'temp', 'rain', 'snow', 'vis']] 47 | hourly_wea['DATE'] = pd.to_datetime(hourly_wea['DATE']).dt.round('H') 48 | hourly_wea.describe().T[['count', 'mean', 'min', 'max']] 49 | # Handle outliers: 50 | hourly_wea.loc[hourly_wea['temp'] < -25, 'temp'] = np.nan 51 | for kk in ['wind', 'temp', 'rain', 'vis']: 52 | hourly_wea = hourly_wea.replace(hourly_wea[kk].max(), np.nan) 53 | hourly_wea.describe().T[['count', 'mean', 'min', 'max']] 54 | # Fillna 55 | hourly_wea['rain'] = hourly_wea['rain'].fillna(0) 56 | hourly_wea['snow'] = hourly_wea['snow'].fillna(0) 57 | for kk in ['wind', 'temp', 'vis']: 58 | hourly_wea[kk] = hourly_wea[kk].fillna(hourly_wea.groupby('DATE')[kk].transform('median')) 59 | # group mean 60 | hourly_wea_mean = hourly_wea.groupby(['DATE']).mean().reset_index() 61 | hourly_wea_mean = hourly_wea_mean[(hourly_wea_mean['DATE'] < t_e) & (hourly_wea_mean['DATE'] >= t_s)].reset_index( 62 | drop=True) 63 | hourly_wea_mean[['DATE', 'wind', 'temp', 'rain', 'snow', 'vis']].to_pickle(r'D:\ST_Graph\Results\weather_2019_bmc.pkl') 64 | # plt.plot(hourly_wea_mean['rain']) 65 | # hourly_wea.groupby(['CTSFIPS', 'DATE']).mean()['rain'].plot() 66 | 67 | # Add socio-economic data 68 | # POI INFO 69 | BMCPOI = pd.read_pickle(r'D:\ST_Graph\Results\BMCPOI_0922.pkl') 70 | CBG_CTS = pd.read_pickle(r'D:\ST_Graph\Results\CBG_CTS.pkl') 71 | for sunit in ['CTSFIPS', 'CBGFIPS', 'CTractFIPS']: # 'CBGFIPS', 'CTractFIPS' 72 | BMCPOI_count = BMCPOI.groupby([sunit, 'top_category']).count()['safegraph_place_id'].reset_index() 73 | BMCPOI_count = BMCPOI_count.pivot(index=sunit, columns='top_category', values='safegraph_place_id').reset_index() 74 | BMCPOI_count = BMCPOI_count.fillna(0) 75 | 76 | # Income etc.
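# The block below aggregates CBG-level rate variables up to each spatial unit
# via population weighting: every rate is first converted to an absolute count
# (rate * Total_Population), the counts are summed per unit, and the sums are
# divided by the unit's total population. Illustration (hypothetical numbers):
# two CBGs with rates 0.2 and 0.4 and populations 100 and 300 give
# (0.2*100 + 0.4*300) / 400 = 0.35 for the enclosing unit.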
77 | CBG_Features = pd.read_csv(r'E:\Research\COVID19-Socio\Data\CBG_COVID_19.csv', index_col=0) 78 | CBG_Features['CBGFIPS'] = CBG_Features['BGFIPS'].astype(str).apply(lambda x: x.zfill(12)) 79 | CBG_Features = CBG_Features.merge(CBG_CTS, on='CBGFIPS') 80 | CBG_Features['CTractFIPS'] = CBG_Features['CBGFIPS'].str[0:11] 81 | 82 | CTS_SUM_POP = CBG_Features.groupby([sunit]).sum()[['Total_Population', 'ALAND']].reset_index() 83 | CTS_SUM_POP.columns = [sunit, 'Total_Population_' + sunit, 'ALAND_' + sunit] 84 | CBG_Features = CBG_Features.merge(CTS_SUM_POP, on=sunit) 85 | 86 | # To abs and then covert to pct 87 | abslist = ['Median_income', 'Democrat_R', 'Republican_R', 'Urbanized_Areas_Population_R', 'HISPANIC_LATINO_R', 88 | 'Black_R', 'Asian_R', 'Bt_18_44_R', 'Bt_45_64_R', 'Over_65_R', 'Male_R', 'White_Non_Hispanic_R', 89 | 'White_Hispanic_R', 'Education_Degree_R'] 90 | for kk in abslist: CBG_Features[kk] = CBG_Features[kk] * CBG_Features['Total_Population'] 91 | CBG_Features_sum = CBG_Features.groupby([sunit]).sum()[abslist].reset_index() 92 | CBG_Features_sum = CBG_Features_sum.merge(CTS_SUM_POP, on=sunit) 93 | for kk in abslist: CBG_Features_sum[kk] = CBG_Features_sum[kk] / CBG_Features_sum['Total_Population_' + sunit] 94 | 95 | CTS_Socio = CBG_Features_sum[[sunit] + abslist] 96 | CTS_Socio = CTS_Socio.merge(CTS_SUM_POP, on=sunit) 97 | CTS_Socio = CTS_Socio.merge(BMCPOI_count, on=sunit) 98 | CTS_Socio = CTS_Socio.sort_values(by=[sunit]).reset_index(drop=True) 99 | CTS_Socio.to_pickle(r'D:\ST_Graph\Results\%s_Socio_bmc.pkl' % sunit) 100 | for kk in list(CTS_Socio.columns)[1:]: CTS_Socio[kk] = (CTS_Socio[kk] - CTS_Socio[kk].mean()) / CTS_Socio[kk].std() 101 | CTS_Socio.rename({sunit: 'geo_id'}, axis=1, inplace=True) 102 | CTS_Socio.to_csv(r'D:\ST_Graph\Results\%s_Hourly_GP.static' % sunit, index=0) 103 | -------------------------------------------------------------------------------- /figures/DC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/figures/DC.png -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/figures/framework.png -------------------------------------------------------------------------------- /figures/topbott__BM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/figures/topbott__BM.png -------------------------------------------------------------------------------- /figures/topbott__DC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/figures/topbott__DC.png -------------------------------------------------------------------------------- /libcity/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/libcity/__init__.py -------------------------------------------------------------------------------- /libcity/config/__init__.py: 
-------------------------------------------------------------------------------- 1 | from libcity.config.config_parser import ConfigParser 2 | 3 | __all__ = [ 4 | 'ConfigParser' 5 | ] 6 | -------------------------------------------------------------------------------- /libcity/config/config_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | 5 | 6 | class ConfigParser(object): 7 | """ 8 | use to parse the user defined parameters and use these to modify the 9 | pipeline's parameter setting. 10 | 值得注意的是,目前各阶段的参数是放置于同一个 dict 中的,因此需要编程时保证命名空间不冲突。 11 | config 优先级:命令行 > config file > default config 12 | """ 13 | 14 | def __init__(self, task, model, dataset, config_file=None, 15 | saved_model=True, train=True, other_args=None, hyper_config_dict=None): 16 | """ 17 | Args: 18 | task, model, dataset (str): 用户在命令行必须指明的三个参数 19 | config_file (str): 配置文件的文件名,将在项目根目录下进行搜索 20 | other_args (dict): 通过命令行传入的其他参数 21 | """ 22 | self.config = {} 23 | self._parse_external_config(task, model, dataset, saved_model, train, other_args, hyper_config_dict) 24 | self._parse_config_file(config_file) 25 | self._load_default_config() 26 | self._init_device() 27 | 28 | def _parse_external_config(self, task, model, dataset, 29 | saved_model=True, train=True, other_args=None, hyper_config_dict=None): 30 | if task is None: 31 | raise ValueError('the parameter task should not be None!') 32 | if model is None: 33 | raise ValueError('the parameter model should not be None!') 34 | if dataset is None: 35 | raise ValueError('the parameter dataset should not be None!') 36 | # 目前暂定这三个参数必须由用户指定 37 | self.config['task'] = task 38 | self.config['model'] = model 39 | self.config['dataset'] = dataset 40 | self.config['saved_model'] = saved_model 41 | self.config['train'] = False if task == 'map_matching' else train 42 | if other_args is not None: 43 | # TODO: 这里可以设计加入参数检查,哪些参数是允许用户通过命令行修改的 44 | for key in other_args: 45 | self.config[key] = other_args[key] 46 | if hyper_config_dict is not None: 47 | # 超参数调整时传入的待调整的参数,优先级低于命令行参数 48 | for key in hyper_config_dict: 49 | self.config[key] = hyper_config_dict[key] 50 | 51 | def _parse_config_file(self, config_file): 52 | if config_file is not None: 53 | # TODO: 对 config file 的格式进行检查 54 | if os.path.exists('./{}.json'.format(config_file)): 55 | with open('./{}.json'.format(config_file), 'r') as f: 56 | x = json.load(f) 57 | for key in x: 58 | if key not in self.config: 59 | self.config[key] = x[key] 60 | else: 61 | raise FileNotFoundError( 62 | 'Config file {}.json is not found. 
Please ensure \ 63 | the config file is in the root dir and is a JSON \ 64 | file.'.format(config_file)) 65 | 66 | def _load_default_config(self): 67 | # 首先加载 task config 68 | with open('./libcity/config/task_config.json', 'r') as f: 69 | task_config = json.load(f) 70 | if self.config['task'] not in task_config: 71 | raise ValueError( 72 | 'task {} is not supported.'.format(self.config['task'])) 73 | task_config = task_config[self.config['task']] 74 | # check model and dataset 75 | if self.config['model'] not in task_config['allowed_model']: 76 | raise ValueError('task {} do not support model {}'.format( 77 | self.config['task'], self.config['model'])) 78 | model = self.config['model'] 79 | # 加载 dataset、executor、evaluator 的模块 80 | if 'dataset_class' not in self.config: 81 | self.config['dataset_class'] = task_config[model]['dataset_class'] 82 | if self.config['task'] == 'traj_loc_pred' and 'traj_encoder' not in self.config: 83 | self.config['traj_encoder'] = task_config[model]['traj_encoder'] 84 | if self.config['task'] == 'eta' and 'eta_encoder' not in self.config: 85 | self.config['eta_encoder'] = task_config[model]['eta_encoder'] 86 | if 'executor' not in self.config: 87 | self.config['executor'] = task_config[model]['executor'] 88 | if 'evaluator' not in self.config: 89 | self.config['evaluator'] = task_config[model]['evaluator'] 90 | # 对于 LSTM RNN GRU 使用的都是同一个类,只是 RNN 模块不一样而已,这里做一下修改 91 | if self.config['model'].upper() in ['LSTM', 'GRU', 'RNN']: 92 | self.config['rnn_type'] = self.config['model'] 93 | self.config['model'] = 'RNN' 94 | # if self.config['dataset'] not in task_config['allowed_dataset']: 95 | # raise ValueError('task {} do not support dataset {}'.format( 96 | # self.config['task'], self.config['dataset'])) 97 | # 接着加载每个阶段的 default config 98 | default_file_list = [] 99 | # model 100 | default_file_list.append('model/{}/{}.json'.format(self.config['task'], self.config['model'])) 101 | # dataset 102 | default_file_list.append('data/{}.json'.format(self.config['dataset_class'])) 103 | # executor 104 | default_file_list.append('executor/{}.json'.format(self.config['executor'])) 105 | # evaluator 106 | default_file_list.append('evaluator/{}.json'.format(self.config['evaluator'])) 107 | # 加载所有默认配置 108 | for file_name in default_file_list: 109 | with open('./libcity/config/{}'.format(file_name), 'r') as f: 110 | x = json.load(f) 111 | for key in x: 112 | if key not in self.config: 113 | self.config[key] = x[key] 114 | # 加载数据集config.json 115 | with open('./raw_data/{}/config.json'.format(self.config['dataset']), 'r') as f: 116 | x = json.load(f) 117 | for key in x: 118 | if key == 'info': 119 | for ik in x[key]: 120 | if ik not in self.config: 121 | self.config[ik] = x[key][ik] 122 | else: 123 | if key not in self.config: 124 | self.config[key] = x[key] 125 | 126 | def _init_device(self): 127 | use_gpu = self.config.get('gpu', True) 128 | gpu_id = self.config.get('gpu_id', 0) 129 | if use_gpu: 130 | torch.cuda.set_device(gpu_id) 131 | self.config['device'] = torch.device( 132 | "cuda:%d" % gpu_id if torch.cuda.is_available() and use_gpu else "cpu") 133 | 134 | def get(self, key, default=None): 135 | return self.config.get(key, default) 136 | 137 | def __getitem__(self, key): 138 | if key in self.config: 139 | return self.config[key] 140 | else: 141 | raise KeyError('{} is not in the config'.format(key)) 142 | 143 | def __setitem__(self, key, value): 144 | self.config[key] = value 145 | 146 | def __contains__(self, key): 147 | return key in self.config 148 | 149 | # 支持迭代操作 150 | 
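    # Illustration (not part of the original class): a call such as
    # ConfigParser('traffic_state_pred', 'MultiATGCN', '201901010601_DC_SG_CTractFIPS_Hourly_Single_GP', config_file='config_user')
    # resolves each key by the priority documented above: command-line arguments >
    # config_user.json > the default JSON files of the model/data/executor/evaluator >
    # the dataset's own raw_data/<dataset>/config.json.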
def __iter__(self): 151 | return self.config.__iter__() 152 | -------------------------------------------------------------------------------- /libcity/config/data/MTHDataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "cache_dataset": true, 4 | "num_workers": 0, 5 | "pad_with_last_sample": true, 6 | "train_rate": 0.7, 7 | "eval_rate": 0.1, 8 | "scaler": "standard", 9 | "load_external": false, 10 | "normal_external": false, 11 | "ext_scaler": "none", 12 | "input_window": 12, 13 | "output_window": 12, 14 | "add_time_in_day": false, 15 | "add_day_in_week": false, 16 | "len_closeness": 1, 17 | "len_period": 1, 18 | "len_trend": 2, 19 | "interval_period": 1, 20 | "interval_trend": 7 21 | } 22 | -------------------------------------------------------------------------------- /libcity/config/data/TrafficStateDataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "cache_dataset": true, 4 | "num_workers": 0, 5 | "pad_with_last_sample": true, 6 | "train_rate": 0.7, 7 | "eval_rate": 0.1, 8 | "scaler": "none", 9 | "load_external": false, 10 | "normal_external": false, 11 | "ext_scaler": "none", 12 | "input_window": 12, 13 | "output_window": 12, 14 | "add_time_in_day": false, 15 | "add_day_in_week": false 16 | } 17 | -------------------------------------------------------------------------------- /libcity/config/data/TrafficStatePointDataset.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "cache_dataset": true, 4 | "num_workers": 0, 5 | "pad_with_last_sample": true, 6 | "train_rate": 0.7, 7 | "eval_rate": 0.1, 8 | "scaler": "none", 9 | "load_external": false, 10 | "normal_external": false, 11 | "ext_scaler": "none", 12 | "input_window": 12, 13 | "output_window": 12, 14 | "add_time_in_day": false, 15 | "add_day_in_week": false 16 | } 17 | -------------------------------------------------------------------------------- /libcity/config/evaluator/TrafficStateEvaluator.json: -------------------------------------------------------------------------------- 1 | { 2 | "metrics": ["MAE", "MAPE", "MSE", "RMSE", "masked_MAE", "masked_MAPE", "masked_MSE", "masked_RMSE", "R2", "EVAR"], 3 | "evaluator_mode": "single", 4 | "save_mode": ["csv"] 5 | } 6 | -------------------------------------------------------------------------------- /libcity/config/executor/TrafficStateExecutor.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpu": true, 3 | "gpu_id": 0, 4 | "max_epoch": 100, 5 | "train_loss": "none", 6 | "epoch": 0, 7 | "learner": "adam", 8 | "learning_rate": 0.01, 9 | "weight_decay": 0, 10 | "lr_epsilon": 1e-8, 11 | "lr_beta1": 0.9, 12 | "lr_beta2": 0.999, 13 | "lr_alpha": 0.99, 14 | "lr_momentum": 0, 15 | "lr_decay": false, 16 | "lr_scheduler": "multisteplr", 17 | "lr_decay_ratio": 0.1, 18 | "steps": [5, 20, 40, 70], 19 | "step_size": 10, 20 | "lr_T_max": 30, 21 | "lr_eta_min": 0, 22 | "lr_patience": 10, 23 | "lr_threshold": 1e-4, 24 | "clip_grad_norm": false, 25 | "max_grad_norm": 1.0, 26 | "use_early_stop": false, 27 | "patience": 50, 28 | "log_level": "INFO", 29 | "log_every": 1, 30 | "saved_model": true, 31 | "load_best_epoch": true, 32 | "hyper_tune": false 33 | } 34 | -------------------------------------------------------------------------------- /libcity/config/model/traffic_state_pred/MultiATGCN.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim_node": 20, 3 | "embed_dim_adj": 20, 4 | "rnn_units": 64, 5 | "num_layers": 2, 6 | "cheb_order": 2, 7 | "use_3tu": true, 8 | "node_specific_off": false, 9 | "gcn_off": false, 10 | "fnn_off": false, 11 | "bidir_adj_mx": false, 12 | "batch_size": 16, 13 | "adpadj": "none", 14 | "adjtype": "cosine", 15 | "scaler": "standard", 16 | "add_static": false, 17 | "ext_scaler": "none", 18 | "learner": "adam", 19 | "learning_rate": 0.003, 20 | "lr_decay": true, 21 | "lr_scheduler": "multisteplr", 22 | "lr_decay_ratio": 0.75, 23 | "steps": [ 24 | 5, 25 | 10, 26 | 20, 27 | 30 28 | ], 29 | "clip_grad_norm": true, 30 | "max_grad_norm": 5 31 | } 32 | -------------------------------------------------------------------------------- /libcity/config/task_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "traffic_state_pred": { 3 | "allowed_model": [ 4 | "DCRNN", 5 | "STGCN", 6 | "GWNET", 7 | "AGCRN", 8 | "TGCLSTM", 9 | "TGCN", 10 | "TemplateTSP", 11 | "ASTGCN", 12 | "MSTGCN", 13 | "MTGNN", 14 | "ACFM", 15 | "STResNet", 16 | "RNN", 17 | "LSTM", 18 | "GRU", 19 | "AutoEncoder", 20 | "Seq2Seq", 21 | "STResNetCommon", 22 | "ACFMCommon", 23 | "ASTGCNCommon", 24 | "MSTGCNCommon", 25 | "ToGCN", 26 | "CONVGCN", 27 | "STG2Seq", 28 | "DMVSTNet", 29 | "ATDM", 30 | "GMAN", 31 | "GTS", 32 | "STDN", 33 | "HGCN", 34 | "STSGCN", 35 | "STAGGCN", 36 | "STNN", 37 | "ResLSTM", 38 | "DGCN", 39 | "MultiSTGCnet", 40 | "STMGAT", 41 | "CRANN", 42 | "STTN", 43 | "CONVGCNCommon", 44 | "DSAN", 45 | "DKFN", 46 | "CCRNN", 47 | "MultiSTGCnetCommon", 48 | "GEML", 49 | "FNN", 50 | "GSNet", 51 | "CSTN", 52 | "MultiATGCN", 53 | "MultiATGCN3UT" 54 | ], 55 | "allowed_dataset": [ 56 | "METR_LA", 57 | "PEMS_BAY", 58 | "PEMSD3", 59 | "PEMSD4", 60 | "PEMSD7", 61 | "PEMSD8", 62 | "PEMSD7(M)", 63 | "LOOP_SEATTLE", 64 | "LOS_LOOP", 65 | "LOS_LOOP_SMALL", 66 | "Q_TRAFFIC", 67 | "SZ_TAXI", 68 | "NYCBike20140409", 69 | "NYCBike20160708", 70 | "NYCBike20160809", 71 | "NYCTaxi20140112", 72 | "NYCTaxi20150103", 73 | "NYCTaxi20160102", 74 | "TAXIBJ", 75 | "T_DRIVE20150206", 76 | "BEIJING_SUBWAY_10MIN", 77 | "BEIJING_SUBWAY_15MIN", 78 | "BEIJING_SUBWAY_30MIN", 79 | "ROTTERDAM", 80 | "HZMETRO", 81 | "SHMETRO", 82 | "M_DENSE", 83 | "PORTO", 84 | "NYCTAXI_DYNA", 85 | "NYCTAXI_OD", 86 | "NYCTAXI_GRID", 87 | "T_DRIVE_SMALL", 88 | "NYCBIKE", 89 | "AUSTINRIDE", 90 | "BIKEDC", 91 | "BIKECHI", 92 | "NYC_RISK", 93 | "CHICAGO_RISK" 94 | ], 95 | "DCRNN": { 96 | "dataset_class": "MTHDataset", 97 | "executor": "DCRNNExecutor", 98 | "evaluator": "TrafficStateEvaluator" 99 | }, 100 | "STGCN": { 101 | "dataset_class": "MTHDataset", 102 | "executor": "TrafficStateExecutor", 103 | "evaluator": "TrafficStateEvaluator" 104 | }, 105 | "GWNET": { 106 | "dataset_class": "MTHDataset", 107 | "executor": "TrafficStateExecutor", 108 | "evaluator": "TrafficStateEvaluator" 109 | }, 110 | "AGCRN": { 111 | "dataset_class": "MTHDataset", 112 | "executor": "TrafficStateExecutor", 113 | "evaluator": "TrafficStateEvaluator" 114 | }, 115 | "MultiATGCN": { 116 | "dataset_class": "MTHDataset", 117 | "executor": "TrafficStateExecutor", 118 | "evaluator": "TrafficStateEvaluator" 119 | }, 120 | "MultiATGCN3UT": { 121 | "dataset_class": "MTHDataset", 122 | "executor": "TrafficStateExecutor", 123 | "evaluator": "TrafficStateEvaluator" 124 | }, 125 | "TGCN": { 126 | "dataset_class": "MTHDataset", 127 | "executor": "TrafficStateExecutor", 128 | 
"evaluator": "TrafficStateEvaluator" 129 | }, 130 | "ASTGCN": { 131 | "dataset_class": "ASTGCNDataset", 132 | "executor": "TrafficStateExecutor", 133 | "evaluator": "TrafficStateEvaluator" 134 | }, 135 | "MTGNN": { 136 | "dataset_class": "MTHDataset", 137 | "executor": "MTGNNExecutor", 138 | "evaluator": "TrafficStateEvaluator" 139 | }, 140 | "RNN": { 141 | "dataset_class": "MTHDataset", 142 | "executor": "TrafficStateExecutor", 143 | "evaluator": "TrafficStateEvaluator" 144 | }, 145 | "LSTM": { 146 | "dataset_class": "MTHDataset", 147 | "executor": "TrafficStateExecutor", 148 | "evaluator": "TrafficStateEvaluator" 149 | }, 150 | "GRU": { 151 | "dataset_class": "MTHDataset", 152 | "executor": "TrafficStateExecutor", 153 | "evaluator": "TrafficStateEvaluator" 154 | }, 155 | "GMAN": { 156 | "dataset_class": "GMANDataset", 157 | "executor": "TrafficStateExecutor", 158 | "evaluator": "TrafficStateEvaluator" 159 | }, 160 | "FNN": { 161 | "dataset_class": "MTHDataset", 162 | "executor": "TrafficStateExecutor", 163 | "evaluator": "TrafficStateEvaluator" 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /libcity/data/__init__.py: -------------------------------------------------------------------------------- 1 | from libcity.data.utils import get_dataset 2 | 3 | __all__ = [ 4 | "get_dataset" 5 | ] 6 | -------------------------------------------------------------------------------- /libcity/data/batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Batch(object): 6 | 7 | def __init__(self, feature_name): 8 | """Summary of class here 9 | 10 | Args: 11 | feature_name (dict): key is the corresponding feature's name, and 12 | the value is the feature's data type 13 | """ 14 | self.data = {} 15 | self.feature_name = feature_name 16 | for key in feature_name: 17 | self.data[key] = [] 18 | 19 | def __getitem__(self, key): 20 | if key in self.data: 21 | return self.data[key] 22 | else: 23 | raise KeyError('{} is not in the batch'.format(key)) 24 | 25 | def __setitem__(self, key, value): 26 | if key in self.data: 27 | self.data[key] = value 28 | else: 29 | raise KeyError('{} is not in the batch'.format(key)) 30 | 31 | def append(self, item): 32 | """ 33 | append a new item into the batch 34 | 35 | Args: 36 | item (list): 一组输入,跟feature_name的顺序一致,feature_name即是这一组输入的名字 37 | """ 38 | if len(item) != len(self.feature_name): 39 | raise KeyError('when append a batch, item is not equal length with feature_name') 40 | for i, key in enumerate(self.feature_name): 41 | self.data[key].append(item[i]) 42 | 43 | def to_tensor(self, device): 44 | """ 45 | 将数据self.data转移到device上 46 | 47 | Args: 48 | device(torch.device): GPU/CPU设备 49 | """ 50 | for key in self.data: 51 | if self.feature_name[key] == 'int': 52 | self.data[key] = torch.LongTensor(np.array(self.data[key])).to(device) 53 | elif self.feature_name[key] == 'float': 54 | self.data[key] = torch.FloatTensor(np.array(self.data[key])).to(device) 55 | else: 56 | raise TypeError( 57 | 'Batch to_tensor, only support int, float but you give {}'.format(self.feature_name[key])) 58 | 59 | def to_ndarray(self): 60 | for key in self.data: 61 | if self.feature_name[key] == 'int': 62 | self.data[key] = np.array(self.data[key]) 63 | elif self.feature_name[key] == 'float': 64 | self.data[key] = np.array(self.data[key]) 65 | else: 66 | raise TypeError( 67 | 'Batch to_ndarray, only support int, float but you give 
{}'.format(self.feature_name[key])) 68 | 69 | 70 | class BatchPAD(Batch): 71 | 72 | def __init__(self, feature_name, pad_item=None, pad_max_len=None): 73 | """Summary of class here 74 | 75 | Args: 76 | feature_name (dict): key is the corresponding feature's name, and 77 | the value is the feature's data type 78 | pad_item (dict): key is the feature name, and value is the padding 79 | value. We will just padding the feature in pad_item 80 | pad_max_len (dict): key is the feature name, and value is the max 81 | length of padded feature. use this parameter to truncate the 82 | feature. 83 | """ 84 | super().__init__(feature_name=feature_name) 85 | # 默认是根据 batch 中每个特征最长的长度来补齐,如果某个特征的长度超过了 pad_max_len 则进行剪切 86 | self.pad_len = {} 87 | self.origin_len = {} # 用于得知补齐前轨迹的原始长度 88 | self.pad_max_len = pad_max_len if pad_max_len is not None else {} 89 | self.pad_item = pad_item if pad_item is not None else {} 90 | for key in feature_name: 91 | self.data[key] = [] 92 | if key in self.pad_item: 93 | self.pad_len[key] = 0 94 | self.origin_len[key] = [] 95 | 96 | def append(self, item): 97 | """ 98 | append a new item into the batch 99 | 100 | Args: 101 | item (list): 一组输入,跟feature_name的顺序一致,feature_name即是这一组输入的名字 102 | """ 103 | if len(item) != len(self.feature_name): 104 | raise KeyError('when append a batch, item is not equal length with feature_name') 105 | for i, key in enumerate(self.feature_name): 106 | # 需保证 item 每个特征的顺序与初始化时传入的 feature_name 中特征的顺序一致 107 | self.data[key].append(item[i]) 108 | if key in self.pad_item: 109 | self.origin_len[key].append(len(item[i])) 110 | if self.pad_len[key] < len(item[i]): 111 | # 保持 pad_len 是最大的 112 | self.pad_len[key] = len(item[i]) 113 | 114 | def padding(self): 115 | """ 116 | 只提供对一维数组的特征进行补齐 117 | """ 118 | for key in self.pad_item: 119 | # 只对在 pad_item 中的特征进行补齐 120 | if key not in self.data: 121 | raise KeyError('when pad a batch, raise this error!') 122 | max_len = self.pad_len[key] 123 | if key in self.pad_max_len: 124 | max_len = min(self.pad_max_len[key], max_len) 125 | for i in range(len(self.data[key])): 126 | if len(self.data[key][i]) < max_len: 127 | self.data[key][i] += [self.pad_item[key]] * \ 128 | (max_len - len(self.data[key][i])) 129 | else: 130 | # 截取的原则是,抛弃前面的点 131 | # 因为是时间序列嘛 132 | self.data[key][i] = self.data[key][i][-max_len:] 133 | # 对于剪切了的,我们没办法还原,但至少不要使他出错 134 | self.origin_len[key][i] = max_len 135 | 136 | def get_origin_len(self, key): 137 | return self.origin_len[key] 138 | 139 | def to_tensor(self, device): 140 | """ 141 | 将数据self.data转移到device上 142 | 143 | Args: 144 | device(torch.device): GPU/CPU设备 145 | """ 146 | for key in self.data: 147 | if self.feature_name[key] == 'int': 148 | self.data[key] = torch.LongTensor(np.array(self.data[key])).to(device) 149 | elif self.feature_name[key] == 'float': 150 | self.data[key] = torch.FloatTensor(np.array(self.data[key])).to(device) 151 | elif self.feature_name[key] == 'array of int': 152 | for i in range(len(self.data[key])): 153 | for j in range(len(self.data[key][i])): 154 | try: 155 | self.data[key][i][j] = torch.LongTensor(np.array(self.data[key][i][j])).to(device) 156 | except TypeError: 157 | print('device is ', device) 158 | exit() 159 | elif self.feature_name[key] == 'no_pad_int': 160 | for i in range(len(self.data[key])): 161 | self.data[key][i] = torch.LongTensor(np.array(self.data[key][i])).to(device) 162 | elif self.feature_name[key] == 'no_pad_float': 163 | for i in range(len(self.data[key])): 164 | self.data[key][i] = 
torch.FloatTensor(np.array(self.data[key][i])).to(device) 165 | elif self.feature_name[key] == 'no_tensor': 166 | pass 167 | else: 168 | raise TypeError( 169 | 'Batch to_tensor, only support int, float but you give {}'.format(self.feature_name[key])) 170 | -------------------------------------------------------------------------------- /libcity/data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from libcity.data.dataset.abstract_dataset import AbstractDataset 2 | from libcity.data.dataset.traffic_state_datatset import TrafficStateDataset 3 | from libcity.data.dataset.traffic_state_point_dataset import TrafficStatePointDataset 4 | 5 | __all__ = [ 6 | "AbstractDataset", 7 | "TrafficStateDataset", 8 | "TrafficStatePointDataset", 9 | ] 10 | -------------------------------------------------------------------------------- /libcity/data/dataset/abstract_dataset.py: -------------------------------------------------------------------------------- 1 | class AbstractDataset(object): 2 | 3 | def __init__(self, config): 4 | raise NotImplementedError("Dataset not implemented") 5 | 6 | def get_data(self): 7 | """ 8 | 返回数据的DataLoader,包括训练数据、测试数据、验证数据 9 | 10 | Returns: 11 | tuple: tuple contains: 12 | train_dataloader: Dataloader composed of Batch (class) \n 13 | eval_dataloader: Dataloader composed of Batch (class) \n 14 | test_dataloader: Dataloader composed of Batch (class) 15 | """ 16 | raise NotImplementedError("get_data not implemented") 17 | 18 | def get_data_feature(self): 19 | """ 20 | 返回一个 dict,包含数据集的相关特征 21 | 22 | Returns: 23 | dict: 包含数据集的相关特征的字典 24 | """ 25 | raise NotImplementedError("get_data_feature not implemented") 26 | -------------------------------------------------------------------------------- /libcity/data/dataset/dataset_subclass/__init__.py: -------------------------------------------------------------------------------- 1 | from libcity.data.dataset.dataset_subclass.mth_dataset import MTHDataset 2 | 3 | __all__ = [ 4 | 5 | "MTHDataset" 6 | ] 7 | -------------------------------------------------------------------------------- /libcity/data/dataset/dataset_subclass/mth_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | from libcity.data.dataset import TrafficStatePointDataset 6 | 7 | 8 | class MTHDataset(TrafficStatePointDataset): 9 | 10 | def __init__(self, config): 11 | super().__init__(config) 12 | self.points_per_hour = 3600 // self.time_intervals # 每小时的时间片数 13 | self.len_closeness = self.config.get('len_closeness', 3) 14 | self.len_period = self.config.get('len_period', 4) 15 | self.len_trend = self.config.get('len_trend', 0) 16 | assert (self.len_closeness + self.len_period + self.len_trend > 0) 17 | self.interval_period = self.config.get('interval_period', 1) # period的长度/天 18 | self.interval_trend = self.config.get('interval_trend', 7) # trend的长度/天 19 | self.feature_name = {'X': 'float', 'y': 'float'} 20 | self.hour_each_day = self.config.get('hour_each_day', 24) # hours contained in a day 21 | self.parameters_str = \ 22 | str(self.dataset) + '_' + str(self.len_closeness) + '_' + str(self.len_period) + '_' + str( 23 | self.len_trend) + '_' + str(self.interval_period) + '_' + str(self.interval_trend) + '_' + str( 24 | self.input_window) + '_' + str(self.output_window) + '_' + str(self.train_rate) + '_' + str( 25 | self.eval_rate) + '_' + str(self.scaler_type) + '_' + str(self.batch_size) + '_' + str( 26 | 
self.load_external) + '_' + str(self.load_dynamic) + '_' + str(self.add_time_in_day) + '_' + str( 27 | self.add_day_in_week) + '_' + str(self.pad_with_last_sample) 28 | self.cache_file_name = os.path.join('./libcity/cache/dataset_cache/', 29 | 'point_based_{}.npz'.format(self.parameters_str)) 30 | 31 | def _search_data(self, sequence_length, label_start_idx, num_for_predict, num_of_depend, units): 32 | """ 33 | 根据全局参数len_closeness/len_period/len_trend找到数据索引的位置 34 | 35 | Args: 36 | sequence_length(int): 历史数据的总长度 37 | label_start_idx(int): 预测开始的时间片的索引 38 | num_for_predict(int): 预测的时间片序列长度 39 | num_of_depend(int): len_trend/len_period/len_closeness 40 | units(int): trend/period/closeness的长度(以小时为单位) 41 | 42 | Returns: 43 | list: 起点-终点区间段的数组,list[(start_idx, end_idx)] 44 | """ 45 | if self.points_per_hour < 0: 46 | raise ValueError("points_per_hour should be greater than 0!") 47 | if label_start_idx + num_for_predict > sequence_length: 48 | return None 49 | x_idx = [] 50 | for i in range(1, num_of_depend + 1): 51 | # 从label_start_idx向左偏移,i是区间数,units*points_per_hour是区间长度(时间片为单位) 52 | start_idx = label_start_idx - int(self.points_per_hour * units * i) 53 | end_idx = start_idx + num_for_predict 54 | if start_idx >= 0: 55 | x_idx.append((start_idx, end_idx)) # 每一段的长度是num_for_predict 56 | else: # i越大越可能有问题,所以遇到错误直接范湖 57 | return None 58 | if len(x_idx) != num_of_depend: 59 | return None 60 | return x_idx[::-1] # 倒序,因为原顺序是从右到左,倒序则从左至右 61 | 62 | def _get_sample_indices(self, data_sequence, label_start_idx): 63 | """ 64 | 根据全局参数len_closeness/len_period/len_trend找到数据预测目标数据 65 | 段: [label_start_idx: label_start_idx+input_window) 66 | 67 | Args: 68 | data_sequence(np.ndarray): 输入数据,shape: (len_time, ..., feature_dim) 69 | label_start_idx(int): the first index of predicting target, 预测开始的时间片的索引 70 | 71 | Returns: 72 | tuple: tuple contains: 73 | trend_sample: 输入数据1, (len_trend * self.input_window, ..., feature_dim) \n 74 | period_sample: 输入数据2, (len_period * self.input_window, ..., feature_dim) \n 75 | closeness_sample: 输入数据3, (len_closeness * self.input_window, ..., feature_dim) \n 76 | target: 输出数据, (self.input_window, ..., feature_dim) 77 | """ 78 | trend_sample, period_sample, closeness_sample = None, None, None 79 | if label_start_idx + self.input_window > data_sequence.shape[0]: 80 | return trend_sample, period_sample, closeness_sample, None 81 | 82 | if self.len_trend > 0: 83 | trend_indices = self._search_data(data_sequence.shape[0], label_start_idx, self.input_window, 84 | self.len_trend, self.interval_trend * self.hour_each_day) 85 | if not trend_indices: 86 | return None, None, None, None 87 | # (len_trend * self.input_window, ..., feature_dim) 88 | trend_sample = np.concatenate([data_sequence[i: j] for i, j in trend_indices], axis=0) 89 | 90 | if self.len_period > 0: 91 | period_indices = self._search_data(data_sequence.shape[0], label_start_idx, self.input_window, 92 | self.len_period, self.interval_period * self.hour_each_day) 93 | if not period_indices: 94 | return None, None, None, None 95 | # (len_period * self.input_window, ..., feature_dim) 96 | period_sample = np.concatenate([data_sequence[i: j] for i, j in period_indices], axis=0) 97 | 98 | if self.len_closeness > 0: 99 | closeness_indices = self._search_data(data_sequence.shape[0], label_start_idx, self.input_window, 100 | self.len_closeness, self.input_window / self.points_per_hour) 101 | if not closeness_indices: 102 | return None, None, None, None 103 | # (len_closeness * self.input_window, ..., feature_dim) 104 | closeness_sample = 
np.concatenate([data_sequence[i: j] for i, j in closeness_indices], axis=0) 105 | 106 | target = data_sequence[label_start_idx: label_start_idx + self.output_window] 107 | # (self.input_window, ..., feature_dim) 108 | return trend_sample, period_sample, closeness_sample, target 109 | 110 | def _generate_input_data(self, df): 111 | """ 112 | 根据全局参数len_closeness/len_period/len_trend切分输入,产生模型需要的输入 113 | 114 | Args: 115 | df(np.ndarray): 输入数据, shape: (len_time, ..., feature_dim) 116 | 117 | Returns: 118 | tuple: tuple contains: 119 | sources(np.ndarray): 模型输入数据, shape: (num_samples, Tw+Td+Th, ..., feature_dim) \n 120 | targets(np.ndarray): 模型输出数据, shape: (num_samples, Tp, ..., feature_dim) 121 | """ 122 | trend_samples, period_samples, closeness_samples, targets = [], [], [], [] 123 | flag = 0 124 | for idx in range(df.shape[0]): 125 | sample = self._get_sample_indices(df, idx) 126 | if (sample[0] is None) and (sample[1] is None) and (sample[2] is None): 127 | continue 128 | flag = 1 129 | trend_sample, period_sample, closeness_sample, target = sample 130 | if self.len_trend > 0: 131 | trend_sample = np.expand_dims(trend_sample, axis=0) # (1,Tw,N,F) 132 | trend_samples.append(trend_sample) 133 | if self.len_period > 0: 134 | period_sample = np.expand_dims(period_sample, axis=0) # (1,Td,N,F) 135 | period_samples.append(period_sample) 136 | if self.len_closeness > 0: 137 | closeness_sample = np.expand_dims(closeness_sample, axis=0) # (1,Th,N,F) 138 | closeness_samples.append(closeness_sample) 139 | target = np.expand_dims(target, axis=0) # (1,Tp,N,F) 140 | targets.append(target) 141 | if flag == 0: 142 | self._logger.warning('Parameter len_closeness/len_period/len_trend is too large ' 143 | 'for the time range of the data!') 144 | sys.exit() 145 | sources = [] 146 | if len(closeness_samples) > 0: 147 | closeness_samples = np.concatenate(closeness_samples, axis=0) # (num_samples,Th,N,F) 148 | sources.append(closeness_samples) 149 | self._logger.info('closeness: ' + str(closeness_samples.shape)) 150 | if len(period_samples) > 0: 151 | period_samples = np.concatenate(period_samples, axis=0) # (num_samples,Td,N,F) 152 | sources.append(period_samples) 153 | self._logger.info('period: ' + str(period_samples.shape)) 154 | if len(trend_samples) > 0: 155 | trend_samples = np.concatenate(trend_samples, axis=0) # (num_samples,Tw,N,F) 156 | sources.append(trend_samples) 157 | self._logger.info('trend: ' + str(trend_samples.shape)) 158 | sources = np.concatenate(sources, axis=1) # (num_samples,Tw+Td+Th,N,F) 159 | targets = np.concatenate(targets, axis=0) # (num_samples,Tp,N,F) 160 | return sources, targets 161 | 162 | def get_data_feature(self): 163 | """ 164 | 返回数据集特征,scaler是归一化方法,adj_mx是邻接矩阵,num_nodes是点的个数, 165 | feature_dim是输入数据的维度,output_dim是模型输出的维度, 166 | len_closeness/len_period/len_trend分别是三段数据的长度 167 | 168 | Returns: 169 | dict: 包含数据集的相关特征的字典 170 | """ 171 | return {"scaler": self.scaler, "adj_mx": self.adj_mx, "static": self.static, 172 | "ct_visit_mstd": self.ct_visit_mstd, "coordinate": self.coordinate, "num_nodes": self.num_nodes, 173 | "feature_dim": self.feature_dim, "output_dim": self.output_dim, 174 | "ext_dim": self.ext_dim, "len_closeness": self.len_closeness * self.input_window, 175 | "len_period": self.len_period * self.input_window, "len_trend": self.len_trend * self.input_window, 176 | "num_batches": self.num_batches} 177 | -------------------------------------------------------------------------------- /libcity/data/dataset/traffic_state_point_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from libcity.data.dataset import TrafficStateDataset 4 | 5 | 6 | class TrafficStatePointDataset(TrafficStateDataset): 7 | 8 | def __init__(self, config): 9 | super().__init__(config) 10 | self.cache_file_name = os.path.join('./libcity/cache/dataset_cache/', 11 | 'point_based_{}.npz'.format(self.parameters_str)) 12 | 13 | def _load_geo(self): 14 | """ 15 | 加载.geo文件,格式[geo_id, type, coordinates, properties(若干列)] 16 | """ 17 | super()._load_geo() 18 | 19 | def _load_rel(self): 20 | """ 21 | 加载.rel文件,格式[rel_id, type, origin_id, destination_id, properties(若干列)] 22 | 23 | Returns: 24 | np.ndarray: self.adj_mx, N*N的邻接矩阵 25 | """ 26 | super()._load_rel() 27 | 28 | def _load_dyna(self, filename): 29 | """ 30 | 加载.dyna文件,格式[dyna_id, type, time, entity_id, properties(若干列)] 31 | 其中全局参数`data_col`用于指定需要加载的数据的列,不设置则默认全部加载 32 | 33 | Args: 34 | filename(str): 数据文件名,不包含后缀 35 | 36 | Returns: 37 | np.ndarray: 数据数组, 3d-array (len_time, num_nodes, feature_dim) 38 | """ 39 | return super()._load_dyna_3d(filename) 40 | 41 | def _add_external_information(self, df, ext_data=None): 42 | """ 43 | 增加外部信息(一周中的星期几/day of week,一天中的某个时刻/time of day,外部数据) 44 | 45 | Args: 46 | df(np.ndarray): 交通状态数据多维数组, (len_time, num_nodes, feature_dim) 47 | ext_data(np.ndarray): 外部数据 48 | 49 | Returns: 50 | np.ndarray: 融合后的外部数据和交通状态数据, (len_time, num_nodes, feature_dim_plus) 51 | """ 52 | return super()._add_external_information_3d(df, ext_data) 53 | 54 | def get_data_feature(self): 55 | """ 56 | 返回数据集特征,scaler是归一化方法,adj_mx是邻接矩阵,num_nodes是点的个数, 57 | feature_dim是输入数据的维度,output_dim是模型输出的维度 58 | 59 | Returns: 60 | dict: 包含数据集的相关特征的字典 61 | """ 62 | return {"scaler": self.scaler, "adj_mx": self.adj_mx, "ext_dim": self.ext_dim, 63 | "num_nodes": self.num_nodes, "feature_dim": self.feature_dim, 64 | "output_dim": self.output_dim, "num_batches": self.num_batches} 65 | -------------------------------------------------------------------------------- /libcity/data/list_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class ListDataset(Dataset): 5 | def __init__(self, data): 6 | """ 7 | data: 必须是一个 list 8 | """ 9 | self.data = data 10 | 11 | def __getitem__(self, index): 12 | return self.data[index] 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | -------------------------------------------------------------------------------- /libcity/data/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | import copy 5 | 6 | from libcity.data.list_dataset import ListDataset 7 | from libcity.data.batch import Batch, BatchPAD 8 | 9 | 10 | def get_dataset(config): 11 | """ 12 | according the config['dataset_class'] to create the dataset 13 | 14 | Args: 15 | config(ConfigParser): config 16 | 17 | Returns: 18 | AbstractDataset: the loaded dataset 19 | """ 20 | try: 21 | return getattr(importlib.import_module('libcity.data.dataset'), 22 | config['dataset_class'])(config) 23 | except AttributeError: 24 | try: 25 | return getattr(importlib.import_module('libcity.data.dataset.dataset_subclass'), 26 | config['dataset_class'])(config) 27 | except AttributeError: 28 | raise AttributeError('dataset_class is not found') 29 | 30 | 31 | def generate_dataloader(train_data, eval_data, test_data, feature_name, 32 | batch_size, num_workers, 
shuffle=True, 33 | pad_with_last_sample=False): 34 | """ 35 | create dataloader(train/test/eval) 36 | 37 | Args: 38 | train_data(list of input): 训练数据,data 中每个元素是模型单次的输入,input 是一个 list,里面存放单次输入和 target 39 | eval_data(list of input): 验证数据,data 中每个元素是模型单次的输入,input 是一个 list,里面存放单次输入和 target 40 | test_data(list of input): 测试数据,data 中每个元素是模型单次的输入,input 是一个 list,里面存放单次输入和 target 41 | feature_name(dict): 描述上面 input 每个元素对应的特征名, 应保证len(feature_name) = len(input) 42 | batch_size(int): batch_size 43 | num_workers(int): num_workers 44 | shuffle(bool): shuffle 45 | pad_with_last_sample(bool): 对于若最后一个 batch 不满足 batch_size的情况,是否进行补齐(使用最后一个元素反复填充补齐)。 46 | 47 | Returns: 48 | tuple: tuple contains: 49 | train_dataloader: Dataloader composed of Batch (class) \n 50 | eval_dataloader: Dataloader composed of Batch (class) \n 51 | test_dataloader: Dataloader composed of Batch (class) 52 | """ 53 | if pad_with_last_sample: 54 | num_padding = (batch_size - (len(train_data) % batch_size)) % batch_size 55 | data_padding = np.repeat(train_data[-1:], num_padding, axis=0) 56 | train_data = np.concatenate([train_data, data_padding], axis=0) 57 | num_padding = (batch_size - (len(eval_data) % batch_size)) % batch_size 58 | data_padding = np.repeat(eval_data[-1:], num_padding, axis=0) 59 | eval_data = np.concatenate([eval_data, data_padding], axis=0) 60 | num_padding = (batch_size - (len(test_data) % batch_size)) % batch_size 61 | data_padding = np.repeat(test_data[-1:], num_padding, axis=0) 62 | test_data = np.concatenate([test_data, data_padding], axis=0) 63 | 64 | train_dataset = ListDataset(train_data) 65 | eval_dataset = ListDataset(eval_data) 66 | test_dataset = ListDataset(test_data) 67 | 68 | def collator(indices): 69 | batch = Batch(feature_name) 70 | for item in indices: 71 | batch.append(copy.deepcopy(item)) 72 | return batch 73 | 74 | train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, 75 | num_workers=num_workers, collate_fn=collator, 76 | shuffle=shuffle) 77 | eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=batch_size, 78 | num_workers=num_workers, collate_fn=collator, 79 | shuffle=shuffle) 80 | test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, 81 | num_workers=num_workers, collate_fn=collator, 82 | shuffle=False) 83 | return train_dataloader, eval_dataloader, test_dataloader 84 | 85 | 86 | def generate_dataloader_pad(train_data, eval_data, test_data, feature_name, 87 | batch_size, num_workers, pad_item=None, 88 | pad_max_len=None, shuffle=True): 89 | """ 90 | create dataloader(train/test/eval) 91 | 92 | Args: 93 | train_data(list of input): 训练数据,data 中每个元素是模型单次的输入,input 是一个 list,里面存放单次输入和 target 94 | eval_data(list of input): 验证数据,data 中每个元素是模型单次的输入,input 是一个 list,里面存放单次输入和 target 95 | test_data(list of input): 测试数据,data 中每个元素是模型单次的输入,input 是一个 list,里面存放单次输入和 target 96 | feature_name(dict): 描述上面 input 每个元素对应的特征名, 应保证len(feature_name) = len(input) 97 | batch_size(int): batch_size 98 | num_workers(int): num_workers 99 | pad_item(dict): 用于将不定长的特征补齐到一样的长度,每个特征名作为 key,若某特征名不在该 dict 内则不进行补齐。 100 | pad_max_len(dict): 用于截取不定长的特征,对于过长的特征进行剪切 101 | shuffle(bool): shuffle 102 | 103 | Returns: 104 | tuple: tuple contains: 105 | train_dataloader: Dataloader composed of Batch (class) \n 106 | eval_dataloader: Dataloader composed of Batch (class) \n 107 | test_dataloader: Dataloader composed of Batch (class) 108 | """ 109 | train_dataset = ListDataset(train_data) 110 | eval_dataset = ListDataset(eval_data) 111 | test_dataset = ListDataset(test_data) 112 | 113 | 
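    # The nested collator below differs from the one in generate_dataloader: it
    # wraps each mini-batch in a BatchPAD, which pads every feature listed in
    # pad_item up to the longest sample in the batch (optionally truncated by
    # pad_max_len) before the batch is handed to the model.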
86 | def generate_dataloader_pad(train_data, eval_data, test_data, feature_name,
87 |                             batch_size, num_workers, pad_item=None,
88 |                             pad_max_len=None, shuffle=True):
89 |     """
90 |     create dataloader(train/test/eval)
91 | 
92 |     Args:
93 |         train_data(list of input): training data; each element is one model input, itself a list holding the single input and the target
94 |         eval_data(list of input): validation data; each element is one model input, itself a list holding the single input and the target
95 |         test_data(list of input): test data; each element is one model input, itself a list holding the single input and the target
96 |         feature_name(dict): names the features of each element of the input above; ensure len(feature_name) = len(input)
97 |         batch_size(int): batch_size
98 |         num_workers(int): num_workers
99 |         pad_item(dict): pads variable-length features to a common length; each feature name is a key, and features absent from this dict are not padded.
100 |         pad_max_len(dict): truncates variable-length features; overly long features are clipped
101 |         shuffle(bool): shuffle
102 | 
103 |     Returns:
104 |         tuple: tuple contains:
105 |             train_dataloader: Dataloader composed of Batch (class) \n
106 |             eval_dataloader: Dataloader composed of Batch (class) \n
107 |             test_dataloader: Dataloader composed of Batch (class)
108 |     """
109 |     train_dataset = ListDataset(train_data)
110 |     eval_dataset = ListDataset(eval_data)
111 |     test_dataset = ListDataset(test_data)
112 | 
113 |     def collator(indices):
114 |         batch = BatchPAD(feature_name, pad_item, pad_max_len)
115 |         for item in indices:
116 |             batch.append(copy.deepcopy(item))
117 |         batch.padding()
118 |         return batch
119 | 
120 |     train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size,
121 |                                   num_workers=num_workers, collate_fn=collator,
122 |                                   shuffle=shuffle)
123 |     eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=batch_size,
124 |                                  num_workers=num_workers, collate_fn=collator,
125 |                                  shuffle=shuffle)
126 |     test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size,
127 |                                  num_workers=num_workers, collate_fn=collator,
128 |                                  shuffle=shuffle)  # note: unlike generate_dataloader above, the test split here honors `shuffle`
129 |     return train_dataloader, eval_dataloader, test_dataloader
-------------------------------------------------------------------------------- /libcity/evaluator/__init__.py: --------------------------------------------------------------------------------
1 | from libcity.evaluator.traffic_state_evaluator import TrafficStateEvaluator
2 | 
3 | 
4 | __all__ = [
5 |     "TrafficStateEvaluator",
6 | ]
7 | 
-------------------------------------------------------------------------------- /libcity/evaluator/abstract_evaluator.py: --------------------------------------------------------------------------------
1 | class AbstractEvaluator(object):
2 | 
3 |     def __init__(self, config):
4 |         raise NotImplementedError('evaluator not implemented')
5 | 
6 |     def collect(self, batch):
7 |         """
8 |         Collect one batch of evaluation input.
9 | 
10 |         Args:
11 |             batch(dict): input data
12 |         """
13 |         raise NotImplementedError('evaluator collect not implemented')
14 | 
15 |     def evaluate(self):
16 |         """
17 |         Return the evaluation results of all batches collected so far.
18 |         """
19 |         raise NotImplementedError('evaluator evaluate not implemented')
20 | 
21 |     def save_result(self, save_path, filename=None):
22 |         """
23 |         Save the evaluation results to the file `filename` under the folder `save_path`.
24 | 
25 |         Args:
26 |             save_path: save directory
27 |             filename: output file name
28 |         """
29 |         raise NotImplementedError('evaluator save_result not implemented')
30 | 
31 |     def clear(self):
32 |         """
33 |         Clear the evaluation information collected from previous batches; call once at the start of each evaluation so earlier inputs do not affect the result.
34 |         """
35 |         raise NotImplementedError('evaluator clear not implemented')
36 | 
-------------------------------------------------------------------------------- /libcity/evaluator/eval_funcs.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | # Mean Square Error
6 | def mse(loc_pred, loc_true):
7 |     assert len(loc_pred) == len(loc_true), 'MSE: prediction and ground truth differ in size'
8 |     return np.mean((loc_pred - loc_true) ** 2)  # fixed: np.mean(sum(...)) returned the *sum* of squared errors
9 | 
10 | 
11 | # Mean Absolute Error
12 | def mae(loc_pred, loc_true):
13 |     assert len(loc_pred) == len(loc_true), 'MAE: prediction and ground truth differ in size'
14 |     return np.mean(np.abs(loc_pred - loc_true))  # fixed: the original omitted the absolute value
15 | 
16 | 
17 | # Root Mean Square Error
18 | def rmse(loc_pred, loc_true):
19 |     assert len(loc_pred) == len(loc_true), 'RMSE: prediction and ground truth differ in size'
20 |     return np.sqrt(np.mean((loc_pred - loc_true) ** 2))  # fixed: was the square root of the summed, not averaged, squares
21 | 
22 | 
23 | # Mean Absolute Percentage Error
24 | def mape(loc_pred, loc_true):
25 |     assert len(loc_pred) == len(loc_true), 'MAPE: prediction and ground truth differ in size'
26 |     assert 0 not in loc_true, 'MAPE: ground truth contains 0, the formula is undefined'
27 |     return np.mean(abs(loc_pred - loc_true) / loc_true)
28 | 
29 | 
30 | # Mean Absolute Relative Error
31 | def mare(loc_pred, loc_true):
32 |     assert len(loc_pred) == len(loc_true), 'MARE: prediction and ground truth differ in size'
33 |     assert np.sum(loc_true) != 0, 'MARE: ground truth sums to 0, the formula is undefined'
34 |     return np.sum(np.abs(loc_pred - loc_true)) / np.sum(loc_true)
35 | 
36 | 
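# ---------------------------------------------------------------------------
# Editor's sanity check for the metrics above (hedged, illustrative values):
#
#   >>> import numpy as np
#   >>> pred, true = np.array([2.0, 4.0]), np.array([1.0, 2.0])
#   >>> mae(pred, true)    # (1 + 2) / 2
#   1.5
#   >>> rmse(pred, true)   # sqrt((1 + 4) / 2)
#   1.5811...
#   >>> mape(pred, true)   # (1/1 + 2/2) / 2
#   1.0
# ---------------------------------------------------------------------------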
37 | # Symmetric Mean Absolute Percentage Error
38 | def smape(loc_pred, loc_true):
39 |     assert len(loc_pred) == len(loc_true), 'SMAPE: prediction and ground truth differ in size'
40 |     assert 0 not in np.abs(loc_pred) + np.abs(loc_true), 'SMAPE: prediction and ground truth are both 0 somewhere, the formula is undefined'  # fixed: the original asserted `0 in ...`, the opposite of the intended check
41 |     return 2.0 * np.mean(np.abs(loc_pred - loc_true) / (np.abs(loc_pred) +
42 |                                                         np.abs(loc_true)))
43 | 
44 | 
45 | # Compare the predicted and true locations to obtain the prediction accuracy
46 | def acc(loc_pred, loc_true):
47 |     assert len(loc_pred) == len(loc_true), 'accuracy: prediction and ground truth differ in size'
48 |     loc_diff = loc_pred - loc_true
49 |     loc_diff[loc_diff != 0] = 1
50 |     return loc_diff, np.mean(loc_diff == 0)
51 | 
52 | 
53 | def top_k(loc_pred, loc_true, topk):
54 |     """
55 |     count the hit numbers of loc_true in topK of loc_pred, used to calculate Precision, Recall and F1-score,
56 |     calculate the reciprocal rank, used to calculate MRR,
57 |     calculate the sum of DCG@K of the batch, used to calculate NDCG
58 | 
59 |     Args:
60 |         loc_pred: (batch_size * output_dim)
61 |         loc_true: (batch_size * 1)
62 |         topk:
63 | 
64 |     Returns:
65 |         tuple: tuple contains:
66 |             hit (int): the hit numbers \n
67 |             rank (float): the sum of the reciprocal rank of input batch \n
68 |             dcg (float): dcg
69 |     """
70 |     assert topk > 0, "top-k ACC: k should be no smaller than 1"
71 |     loc_pred = torch.FloatTensor(loc_pred)
72 |     val, index = torch.topk(loc_pred, topk, 1)
73 |     index = index.numpy()
74 |     hit = 0
75 |     rank = 0.0
76 |     dcg = 0.0
77 |     for i, p in enumerate(index):
78 |         target = loc_true[i]
79 |         if target in p:
80 |             hit += 1
81 |             rank_list = list(p)
82 |             rank_index = rank_list.index(target)
83 |             # rank_index starts from 0, so add 1
84 |             rank += 1.0 / (rank_index + 1)
85 |             dcg += 1.0 / np.log2(rank_index + 2)
86 |     return hit, rank, dcg
87 | 
88 | def Precision_torch(preds, labels, topk):
89 |     precision = []
90 |     for i in range(preds.shape[0]):
91 |         label = labels[i]
92 |         pred = preds[i]
93 |         accident_grids = label > 0
94 |         sorted, _ = torch.sort(pred.flatten(), descending=True)
95 |         threshold = sorted[topk - 1]
96 |         pred_grids = pred >= threshold
97 |         matched = pred_grids & accident_grids
98 |         precision.append(torch.sum(matched.flatten()).item() / topk)
99 |     return sum(precision) / len(precision)
100 | 
101 | def Recall_torch(preds, labels, topk):
102 |     recall = []
103 |     for i in range(preds.shape[0]):
104 |         label = labels[i]
105 |         pred = preds[i]
106 |         accident_grids = label > 0
107 |         sorted, _ = torch.sort(pred.flatten(), descending=True)
108 |         threshold = sorted[topk - 1]
109 |         pred_grids = pred >= threshold
110 |         matched = pred_grids & accident_grids
111 |         if torch.sum(accident_grids).item() != 0:
112 |             recall.append(torch.sum(matched.flatten()).item() / torch.sum(accident_grids.flatten()).item())
113 |     return sum(recall) / len(recall)
114 | 
115 | def F1_Score_torch(preds, labels, topk):
116 |     precision = Precision_torch(preds, labels, topk)
117 |     recall = Recall_torch(preds, labels, topk)
118 |     return 2 * precision * recall / (precision + recall)
119 | 
120 | 
121 | 
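# ---------------------------------------------------------------------------
# Editor's toy example for the top-k grid metrics above (hedged):
#
#   >>> import torch
#   >>> pred  = torch.tensor([[0.9, 0.1, 0.8, 0.2]])   # one sample, four grids
#   >>> label = torch.tensor([[1.0, 0.0, 0.0, 1.0]])   # two accident grids
#   >>> Precision_torch(pred, label, 2)   # top-2 = grids 0 and 2, one is a hit
#   0.5
#   >>> Recall_torch(pred, label, 2)      # 1 of the 2 accident grids recovered
#   0.5
# ---------------------------------------------------------------------------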
122 | def MAP_torch(preds, labels, topk):
123 |     ap = []
124 |     for i in range(preds.shape[0]):
125 |         label = labels[i].flatten()
126 |         pred = preds[i].flatten()
127 |         accident_grids = label > 0
128 |         sorted, rank = torch.sort(pred, descending=True)
129 |         rank = rank[:topk]
130 |         if topk != 0:
131 |             threshold = sorted[topk - 1]
132 |         else:
133 |             threshold = 0
134 |         label = label != 0
135 |         pred = pred >= threshold
136 |         matched = pred & label
137 |         match_num = 0
138 |         precision_sum = 0
139 |         for i in range(rank.shape[0]):  # note: this inner loop variable shadows the sample index i above
140 |             if matched[rank[i]]:
141 |                 match_num += 1
142 |                 precision_sum += match_num / (i + 1)
143 |         if rank.shape[0] != 0:
144 |             ap.append(precision_sum / rank.shape[0])
145 |     return sum(ap) / len(ap)
146 | 
147 | 
148 | def PCC_torch(preds, labels, topk):
149 |     pcc = []
150 |     for i in range(preds.shape[0]):
151 |         label = labels[i].flatten()
152 |         pred = preds[i].flatten()
153 |         sorted, rank = torch.sort(pred, descending=True)
154 |         pred = sorted[:topk]
155 |         rank = rank[:topk]
156 |         sorted_label = torch.zeros(topk)
157 |         for i in range(topk):
158 |             sorted_label[i] = label[rank[i]]
159 |         label = sorted_label
160 |         label_average = torch.sum(label) / (label.shape[0])
161 |         pred_average = torch.sum(pred) / (pred.shape[0])
162 |         if torch.sqrt(torch.sum((label - label_average) * (label - label_average))) * torch.sqrt(
163 |                 torch.sum((pred - pred_average) * (pred - pred_average))) != 0:
164 |             pcc.append((torch.sum((label - label_average) * (pred - pred_average)) / (
165 |                     torch.sqrt(torch.sum((label - label_average) * (label - label_average))) * torch.sqrt(
166 |                 torch.sum((pred - pred_average) * (pred - pred_average))))).item())
167 |     return sum(pcc) / len(pcc)
168 | 
-------------------------------------------------------------------------------- /libcity/evaluator/traffic_state_evaluator.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import datetime
4 | import pandas as pd
5 | from libcity.utils import ensure_dir
6 | from libcity.model import loss
7 | from logging import getLogger
8 | from libcity.evaluator.abstract_evaluator import AbstractEvaluator
9 | 
10 | 
11 | class TrafficStateEvaluator(AbstractEvaluator):
12 | 
13 |     def __init__(self, config):
14 |         self.metrics = config.get('metrics', ['MAE'])  # list of evaluation metrics
15 |         self.allowed_metrics = ['MAE', 'MSE', 'RMSE', 'MAPE', 'masked_MAE',
16 |                                 'masked_MSE', 'masked_RMSE', 'masked_MAPE', 'R2', 'EVAR']
17 |         self.save_modes = config.get('save_mode', ['csv', 'json'])
18 |         self.mode = config.get('evaluator_mode', 'single')  # or average
19 |         self.config = config
20 |         self.min_s = config.get('min_s', 1e-4)
21 |         self.len_timeslots = 0
22 |         self.result = {}  # final value of each metric
23 |         self.intermediate_result = {}  # per-batch values of each metric
24 |         self._check_config()
25 |         self._logger = getLogger()
26 | 
27 |     def _check_config(self):
28 |         if not isinstance(self.metrics, list):
29 |             raise TypeError('Evaluator type is not list')
30 |         for metric in self.metrics:
31 |             if metric not in self.allowed_metrics:
32 |                 raise ValueError('the metric {} is not allowed in TrafficStateEvaluator'.format(str(metric)))
33 | 
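# ---------------------------------------------------------------------------
# Editor's note (hedged): results are keyed per metric and horizon step. With
# metrics=['MAE', 'RMSE'] and a 3-step horizon, collect() accumulates
#   {'MAE@1': [...], 'MAE@2': [...], 'MAE@3': [...], 'RMSE@1': [...], ...}
# where each list holds one value per collected batch; evaluate() then
# averages each list into self.result.
# ---------------------------------------------------------------------------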
34 |     def collect(self, batch):
35 |         """
36 |         Collect one batch of evaluation input.
37 | 
38 |         Args:
39 |             batch(dict): input data, a dict containing two keys (y_true, y_pred):
40 |                 batch['y_true']: (num_samples/batch_size, timeslots, ..., feature_dim)
41 |                 batch['y_pred']: (num_samples/batch_size, timeslots, ..., feature_dim)
42 |         """
43 |         if not isinstance(batch, dict):
44 |             raise TypeError('evaluator.collect input is not a dict of user')
45 |         y_true = batch['y_true']  # tensor
46 |         y_pred = batch['y_pred']  # tensor
47 |         if y_true.shape != y_pred.shape:
48 |             raise ValueError("batch['y_true'].shape is not equal to batch['y_pred'].shape")
49 |         self.len_timeslots = y_true.shape[1]
50 |         for i in range(1, self.len_timeslots + 1):
51 |             for metric in self.metrics:
52 |                 if metric + '@' + str(i) not in self.intermediate_result:
53 |                     self.intermediate_result[metric + '@' + str(i)] = []
54 |         if self.mode.lower() == 'average':  # average loss over the first i time steps
55 |             for i in range(1, self.len_timeslots + 1):
56 |                 for metric in self.metrics:
57 |                     if metric == 'masked_MAE':
58 |                         self.intermediate_result[metric + '@' + str(i)].append(
59 |                             loss.masked_mae_torch(y_pred[:, :i], y_true[:, :i], 0, min_s=self.min_s).item())
60 |                     elif metric == 'masked_MSE':
61 |                         self.intermediate_result[metric + '@' + str(i)].append(
62 |                             loss.masked_mse_torch(y_pred[:, :i], y_true[:, :i], 0, min_s=self.min_s).item())
63 |                     elif metric == 'masked_RMSE':
64 |                         self.intermediate_result[metric + '@' + str(i)].append(
65 |                             loss.masked_rmse_torch(y_pred[:, :i], y_true[:, :i], 0, min_s=self.min_s).item())
66 |                     elif metric == 'masked_MAPE':
67 |                         self.intermediate_result[metric + '@' + str(i)].append(
68 |                             loss.masked_mape_torch(y_pred[:, :i], y_true[:, :i], 0, min_s=self.min_s).item())
69 |                     elif metric == 'MAE':
70 |                         self.intermediate_result[metric + '@' + str(i)].append(
71 |                             loss.masked_mae_torch(y_pred[:, :i], y_true[:, :i]).item())
72 |                     elif metric == 'MSE':
73 |                         self.intermediate_result[metric + '@' + str(i)].append(
74 |                             loss.masked_mse_torch(y_pred[:, :i], y_true[:, :i]).item())
75 |                     elif metric == 'RMSE':
76 |                         self.intermediate_result[metric + '@' + str(i)].append(
77 |                             loss.masked_rmse_torch(y_pred[:, :i], y_true[:, :i]).item())
78 |                     elif metric == 'MAPE':
79 |                         self.intermediate_result[metric + '@' + str(i)].append(
80 |                             loss.masked_mape_torch(y_pred[:, :i], y_true[:, :i]).item())
81 |                     elif metric == 'R2':
82 |                         self.intermediate_result[metric + '@' + str(i)].append(
83 |                             loss.r2_score_torch(y_pred[:, :i], y_true[:, :i]).item())
84 |                     elif metric == 'EVAR':
85 |                         self.intermediate_result[metric + '@' + str(i)].append(
86 |                             loss.explained_variance_score_torch(y_pred[:, :i], y_true[:, :i]).item())
87 |         elif self.mode.lower() == 'single':  # loss at the i-th time step only
88 |             for i in range(1, self.len_timeslots + 1):
89 |                 for metric in self.metrics:
90 |                     if metric == 'masked_MAE':
91 |                         self.intermediate_result[metric + '@' + str(i)].append(
92 |                             loss.masked_mae_torch(y_pred[:, i - 1], y_true[:, i - 1], 0, min_s=self.min_s).item())
93 |                     elif metric == 'masked_MSE':
94 |                         self.intermediate_result[metric + '@' + str(i)].append(
95 |                             loss.masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0, min_s=self.min_s).item())
96 |                     elif metric == 'masked_RMSE':
97 |                         self.intermediate_result[metric + '@' + str(i)].append(
98 |                             loss.masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0, min_s=self.min_s).item())
99 |                     elif metric == 'masked_MAPE':
100 |                         self.intermediate_result[metric + '@' + str(i)].append(
101 |                             loss.masked_mape_torch(y_pred[:, i - 1], y_true[:, i - 1], 0, min_s=self.min_s).item())
102 |                     elif metric == 'MAE':
103 |                         self.intermediate_result[metric + '@' + str(i)].append(
104 |                             loss.masked_mae_torch(y_pred[:, i - 1], y_true[:, i - 1]).item())
105 |                     elif metric == 'MSE':
106 |                         self.intermediate_result[metric + '@' + str(i)].append(
107 |                             loss.masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1]).item())
108 |                     elif metric == 'RMSE':
109 |                         self.intermediate_result[metric + '@' + str(i)].append(
110 |                             loss.masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1]).item())
111 |                     elif metric == 'MAPE':
112 |                         self.intermediate_result[metric + '@' + str(i)].append(
113 |                             loss.masked_mape_torch(y_pred[:, i - 1], y_true[:, i - 1]).item())
114 |                     elif metric == 'R2':
115 |                         self.intermediate_result[metric + '@' + str(i)].append(
116 |                             loss.r2_score_torch(y_pred[:, i - 1], y_true[:, i - 1]).item())
117 |                     elif metric == 'EVAR':
118 |                         self.intermediate_result[metric + '@' + str(i)].append(
119 |                             loss.explained_variance_score_torch(y_pred[:, i - 1], y_true[:, i - 1]).item())
120 |         else:
121 |             raise ValueError('Error parameter evaluator_mode={}, please set `single` or `average`.'.format(self.mode))
122 | 
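# ---------------------------------------------------------------------------
# Editor's note and usage sketch (hedged; shapes and config keys are
# illustrative). The two modes differ in what 'MAE@i' means: with
# evaluator_mode='single' it is the error at horizon step i alone
# (y[:, i - 1]); with 'average' it is the error pooled over the first i steps
# (y[:, :i]), so 'MAE@3' covers steps 1-3. A minimal lifecycle, with a plain
# dict standing in for the real ConfigParser:
#
#   >>> import torch
#   >>> ev = TrafficStateEvaluator({'metrics': ['MAE'], 'model': 'MultiATGCN',
#   ...                             'dataset': 'demo'})
#   >>> y = torch.rand(8, 12, 50, 1)            # (batch, horizon, nodes, dim)
#   >>> ev.collect({'y_true': y, 'y_pred': y})
#   >>> ev.evaluate()['MAE@12']                 # a perfect prediction
#   0.0
#   >>> ev.clear()                              # reset before the next run
# ---------------------------------------------------------------------------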
123 |     def evaluate(self):
124 |         """
125 |         Return the evaluation results of all batches collected so far.
126 |         """
127 |         for i in range(1, self.len_timeslots + 1):
128 |             for metric in self.metrics:
129 |                 self.result[metric + '@' + str(i)] = sum(self.intermediate_result[metric + '@' + str(i)]) / \
130 |                                                      len(self.intermediate_result[metric + '@' + str(i)])
131 |         return self.result
132 | 
133 |     def save_result(self, save_path, filename=None):
134 |         """
135 |         Save the evaluation results to the file `filename` under the folder `save_path`.
136 | 
137 |         Args:
138 |             save_path: save directory
139 |             filename: output file name
140 |         """
141 |         self._logger.info('Note that you select the {} mode to evaluate!'.format(self.mode))
142 |         self.evaluate()
143 |         ensure_dir(save_path)
144 |         if filename is None:  # use a timestamp
145 |             filename = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') + '_' + \
146 |                        self.config['model'] + '_' + self.config['dataset']
147 | 
148 |         if 'json' in self.save_modes:
149 |             self._logger.info('Evaluate result is ' + json.dumps(self.result))
150 |             with open(os.path.join(save_path, '{}.json'.format(filename)), 'w') as f:
151 |                 json.dump(self.result, f)
152 |             self._logger.info('Evaluate result is saved at ' +
153 |                               os.path.join(save_path, '{}.json'.format(filename)))
154 | 
155 |         dataframe = {}
156 |         if 'csv' in self.save_modes:
157 |             for metric in self.metrics:
158 |                 dataframe[metric] = []
159 |             for i in range(1, self.len_timeslots + 1):
160 |                 for metric in self.metrics:
161 |                     dataframe[metric].append(self.result[metric + '@' + str(i)])
162 |             dataframe = pd.DataFrame(dataframe, index=range(1, self.len_timeslots + 1))
163 |             dataframe.to_csv(os.path.join(save_path, '{}.csv'.format(filename)), index=False)
164 |             self._logger.info('Evaluate result is saved at ' + os.path.join(save_path, '{}.csv'.format(filename)))
165 |             self._logger.info("\n" + str(dataframe[['MAE', 'masked_MAE', 'masked_MAPE', 'masked_RMSE']]))
166 |             self._logger.info("\n" + str(dataframe[['MAE', 'masked_MAE', 'masked_MAPE', 'masked_RMSE']].mean()))
167 |         return dataframe
168 | 
169 |     def clear(self):
170 |         """
171 |         Clear the evaluation information collected from previous batches; call once at the start of each evaluation so earlier inputs do not affect the result.
172 |         """
173 |         self.result = {}
174 |         self.intermediate_result = {}
175 | 
-------------------------------------------------------------------------------- /libcity/evaluator/utils.py: --------------------------------------------------------------------------------
1 | import json
2 | from heapq import nlargest
3 | import pandas as pd
4 | from libcity.model.loss import *
5 | 
6 | 
7 | def output(method, value, field):
8 |     """
9 |     Args:
10 |         method: evaluation metric
11 |         value: the resulting value of that metric
12 |         field: evaluation scope, a single trajectory or the whole model
13 |     """
14 |     if method == 'ACC':
15 |         if field == 'model':
16 |             print('---- Under metric {} the model achieves avg_acc={:.3f} ----'.format(method,
17 |                                                                                        value))
18 |         else:
19 |             print('{} avg_acc={:.3f}'.format(method, value))
20 |     elif method in ['MSE', 'RMSE', 'MAE', 'MAPE', 'MARE', 'SMAPE']:
21 |         if field == 'model':
22 |             print('---- Under metric {} the model achieves avg_loss={:.3f} ----'.format(method,
23 |                                                                                         value))
24 |         else:
25 |             print('{} avg_loss={:.3f}'.format(method, value))
26 |     else:
27 |         if field == 'model':
28 |             print('---- Under metric {} the model achieves avg_acc={:.3f} ----'.format(method,
29 |                                                                                        value))
30 |         else:
31 |             print('{} avg_acc={:.3f}'.format(method, value))
32 | 
33 | 
34 | def transfer_data(data, model, maxk):
35 |     """
36 |     Here we transform specific data types to standard input type
37 |     """
38 |     if type(data) == str:
39 |         data = json.loads(data)
40 |     assert type(data) == dict, 'the data to be evaluated has an invalid type/format'
41 |     if model == 'DeepMove':
42 |         user_idx = data.keys()
43 |         for user_id in user_idx:
44 |             trace_idx = data[user_id].keys()
45 |             for trace_id in trace_idx:
46 |                 trace = data[user_id][trace_id]
47 |                 loc_pred = trace['loc_pred']
48 |                 new_loc_pred = []
49 |                 for t_list in loc_pred:
50 |                     new_loc_pred.append(sort_confidence_ids(t_list, maxk))
51 |                 data[user_id][trace_id]['loc_pred'] = new_loc_pred
52 |     return data
53 | 
54 | 
55 | def sort_confidence_ids(confidence_list, threshold):
56 |     """
57 |     Here we convert the prediction results of the DeepMove model
58 |     DeepMove model output: confidence of all locations
59 |     Evaluate model input: location ids based on confidence
60 |     :param threshold: maxK
61 |     :param confidence_list:
62 |     :return: ids_list
63 |     """
64 |     """sorted_list = sorted(confidence_list, reverse=True)
65 |     mark_list = [0 for i in confidence_list]
66 |     ids_list = []
67 |     for item in sorted_list:
68 |         for i in range(len(confidence_list)):
69 |             if confidence_list[i] == item and mark_list[i] == 0:
70 |                 mark_list[i] = 1
71 |                 ids_list.append(i)
72 |                 break
73 |         if len(ids_list) == threshold:
74 |             break
75 |     return ids_list"""
76 |     max_score_with_id = nlargest(
77 |         threshold, enumerate(confidence_list), lambda x: x[1])
78 |     return list(map(lambda x: x[0], max_score_with_id))  # e.g. [0.1, 0.5, 0.2, 0.9] with threshold=2 -> [3, 1]
79 | 
80 | 
81 | def evaluate_model(y_pred, y_true, metrics, mode='single', path='metrics.csv'):
82 |     """
83 |     Evaluation function for traffic state prediction
84 |     :param y_pred: (num_samples/batch_size, timeslots, ..., feature_dim)
85 |     :param y_true: (num_samples/batch_size, timeslots, ..., feature_dim)
86 |     :param metrics: evaluation metrics
87 |     :param mode: 'single' (per step) or 'average' (over the first i steps)
88 |     :param path: path to save the results
89 |     :return:
90 |     """
91 |     if y_true.shape != y_pred.shape:
92 |         raise ValueError("y_true.shape is not equal to y_pred.shape")
93 |     len_timeslots = y_true.shape[1]
94 |     if isinstance(y_pred, np.ndarray):
95 |         y_pred = torch.FloatTensor(y_pred)
96 |     if isinstance(y_true, np.ndarray):
97 |         y_true = torch.FloatTensor(y_true)
98 |     assert isinstance(y_pred, torch.Tensor)
99 |     assert isinstance(y_true, torch.Tensor)
100 | 
101 |     df = []
102 |     for i in range(1, len_timeslots + 1):
103 |         line = {}
104 |         for metric in metrics:
105 |             if mode.lower() == 'single':
106 |                 if metric == 'masked_MAE':
107 |                     line[metric] = masked_mae_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item()
108 |                 elif metric == 'masked_MSE':
109 |                     line[metric] = masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item()
110 |                 elif metric == 'masked_RMSE':
111 |                     line[metric] = masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item()
112 |                 elif metric == 'masked_MAPE':
113 |                     line[metric] = masked_mape_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item()
114 |                 elif metric == 'MAE':
115 |                     line[metric] = masked_mae_torch(y_pred[:, i - 1], y_true[:, i - 1]).item()
116 |                 elif metric == 'MSE':
117 |                     line[metric] = masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1]).item()
118 |                 elif metric == 'RMSE':
119 |                     line[metric] = masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1]).item()
120 |                 elif metric == 'MAPE':
121 |                     line[metric] = masked_mape_torch(y_pred[:, i - 1], y_true[:, i - 1]).item()
122 |                 elif metric == 'R2':
123 |                     line[metric] = r2_score_torch(y_pred[:, i - 1], y_true[:, i - 1]).item()
124 |                 elif metric == 'EVAR':
125 |                     line[metric] = explained_variance_score_torch(y_pred[:, i - 1], y_true[:, i - 1]).item()
126 |                 else:
127 |                     raise ValueError('Error parameter metric={}!'.format(metric))  # fixed: this branch signals an unknown metric, not a bad mode
128 |             elif mode.lower() == 'average':
129 |                 if metric == 'masked_MAE':
130 |                     line[metric] = masked_mae_torch(y_pred[:, :i], y_true[:, :i], 0).item()
131 |                 elif metric == 'masked_MSE':
132 |                     line[metric] = masked_mse_torch(y_pred[:, :i], y_true[:, :i], 0).item()
133 |                 elif metric == 'masked_RMSE':
134 |                     line[metric] = masked_rmse_torch(y_pred[:, :i], y_true[:, :i], 0).item()
135 |                 elif metric == 'masked_MAPE':
136 |                     line[metric] = masked_mape_torch(y_pred[:, :i], y_true[:, :i], 0).item()
137 |                 elif metric == 'MAE':
138 |                     line[metric] = masked_mae_torch(y_pred[:, :i], y_true[:, :i]).item()
139 |                 elif metric == 'MSE':
140 |                     line[metric] = masked_mse_torch(y_pred[:, :i], y_true[:, :i]).item()
141 |                 elif metric == 'RMSE':
142 |                     line[metric] = masked_rmse_torch(y_pred[:, :i], y_true[:, :i]).item()
143 |                 elif metric == 'MAPE':
144 |                     line[metric] = masked_mape_torch(y_pred[:, :i], y_true[:, :i]).item()
145 |                 elif metric == 'R2':
146 |                     line[metric] = r2_score_torch(y_pred[:, :i], y_true[:, :i]).item()
147 |                 elif metric == 'EVAR':
148 |                     line[metric] = explained_variance_score_torch(y_pred[:, :i], y_true[:, :i]).item()
149 |                 else:
150 |                     raise ValueError('Error parameter metric={}!'.format(metric))
151 |             else:
152 |                 raise ValueError('Error parameter evaluator_mode={}, please set `single` or `average`.'.format(mode))
153 |         df.append(line)
154 |     df = pd.DataFrame(df, columns=metrics)
155 |     print(df)
156 |     df.to_csv(path)
157 |     return df
158 | 
-------------------------------------------------------------------------------- /libcity/executor/__init__.py: --------------------------------------------------------------------------------
1 | from libcity.executor.hyper_tuning import HyperTuning
2 | 
3 | from libcity.executor.traffic_state_executor import TrafficStateExecutor
4 | 
5 | __all__ = [
6 |     "TrafficStateExecutor",
7 |     "HyperTuning",
8 | ]
9 | 
-------------------------------------------------------------------------------- /libcity/executor/abstract_executor.py: --------------------------------------------------------------------------------
1 | class AbstractExecutor(object):
2 | 
3 |     def __init__(self, config, model, data_feature):
4 |         raise NotImplementedError("Executor not implemented")
5 | 
6 |     def train(self, train_dataloader, eval_dataloader):
7 |         """
8 |         use data to train model with config
9 | 
10 |         Args:
11 |             train_dataloader(torch.Dataloader): Dataloader
12 |             eval_dataloader(torch.Dataloader): Dataloader
13 |         """
14 |         raise NotImplementedError("Executor train not implemented")
15 | 
16 |     def evaluate(self, test_dataloader):
17 |         """
18 |         use model to test data
19 | 
20 |         Args:
21 |             test_dataloader(torch.Dataloader): Dataloader
22 |         """
23 |         raise NotImplementedError("Executor evaluate not implemented")
24 | 
25 |     def load_model(self, cache_name):
26 |         """
27 |         Load the model from the corresponding cache file.
28 | 
29 |         Args:
30 |             cache_name(str): name of the cached file
31 |         """
32 |         raise NotImplementedError("Executor load cache not implemented")
33 | 
34 |     def save_model(self, cache_name):
35 |         """
36 |         Save the current model to a file.
37 | 
38 |         Args:
39 |             cache_name(str): name of the file to save to
40 |         """
41 |         raise NotImplementedError("Executor save cache not implemented")
42 | 
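# ---------------------------------------------------------------------------
# Editor's sketch of the executor contract (hedged, illustrative only): a
# concrete executor such as TrafficStateExecutor is expected to override all
# four methods; note that the base __init__ raises, so subclasses must not
# call super().__init__(). `MyExecutor` below is a hypothetical name.
#
#   class MyExecutor(AbstractExecutor):
#       def __init__(self, config, model, data_feature):
#           self.model = model                    # store what train/evaluate need
#       def train(self, train_dataloader, eval_dataloader):
#           ...                                   # optimize, validate, early-stop
#       def evaluate(self, test_dataloader):
#           ...                                   # feed predictions to an evaluator
#       def load_model(self, cache_name): ...
#       def save_model(self, cache_name): ...
# ---------------------------------------------------------------------------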
-------------------------------------------------------------------------------- /libcity/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/libcity/model/__init__.py
-------------------------------------------------------------------------------- /libcity/model/abstract_model.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class AbstractModel(nn.Module):
5 | 
6 |     def __init__(self, config, data_feature):
7 |         nn.Module.__init__(self)
8 | 
9 |     def predict(self, batch):
10 |         """
11 |         Args:
12 |             batch (Batch): a batch of input
13 | 
14 |         Returns:
15 |             torch.tensor: predict result of this batch
16 |         """
17 | 
18 |     def calculate_loss(self, batch):
19 |         """
20 |         Args:
21 |             batch (Batch): a batch of input
22 | 
23 |         Returns:
24 |             torch.tensor: return training loss
25 |         """
26 | 
-------------------------------------------------------------------------------- /libcity/model/abstract_traffic_state_model.py: --------------------------------------------------------------------------------
1 | from libcity.model.abstract_model import AbstractModel
2 | 
3 | 
4 | class AbstractTrafficStateModel(AbstractModel):
5 | 
6 |     def __init__(self, config, data_feature):
7 |         self.data_feature = data_feature
8 |         super().__init__(config, data_feature)
9 | 
10 |     def predict(self, batch):
11 |         """
12 |         Given one batch of input, return the corresponding predictions, normally the result of **multi-step** prediction; this usually calls nn.Module's forward() method
13 | 
14 |         Args:
15 |             batch (Batch): a batch of input
16 | 
17 |         Returns:
18 |             torch.tensor: predict result of this batch
19 |         """
20 | 
21 |     def calculate_loss(self, batch):
22 |         """
23 |         Given one batch of input, return the training loss; that is, a loss function must be defined here
24 | 
25 |         Args:
26 |             batch (Batch): a batch of input
27 | 
28 |         Returns:
29 |             torch.tensor: return training loss
30 |         """
31 | 
-------------------------------------------------------------------------------- /libcity/model/abstract_traffic_tradition_model.py: --------------------------------------------------------------------------------
1 | class AbstractTraditionModel:
2 | 
3 |     def __init__(self, config, data_feature):
4 |         self.data_feature = data_feature
5 | 
6 |     def run(self, data):
7 |         """
8 |         Args:
9 |             data : input of tradition model
10 | 
11 |         Returns:
12 |             output of tradition model
13 |         """
14 | 
-------------------------------------------------------------------------------- /libcity/model/loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import r2_score, explained_variance_score
4 | 
5 | 
6 | def masked_mae_loss(y_pred, y_true):
7 |     mask = (y_true != 0).float()
8 |     mask /= mask.mean()
9 |     loss = torch.abs(y_pred - y_true)
10 |     loss = loss * mask
11 |     # trick for nans:
12 |     # https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
13 |     loss[loss != loss] = 0
14 |     return loss.mean()
15 | 
16 | 
17 | def masked_mae_torch(preds, labels, null_val=np.nan, min_s=1e-4):
18 |     labels[torch.abs(labels) < min_s] = 0
19 |     if np.isnan(null_val):
20 |         mask = ~torch.isnan(labels)
21 |     else:
22 |         mask = labels.ne(null_val)
23 |     mask = mask.float()
24 |     mask /= torch.mean(mask)
25 |     mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
26 |     loss = torch.abs(torch.sub(preds, labels))
27 |     loss = loss * mask
28 |     loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
29 |     return torch.mean(loss)
30 | 
31 | 
32 | def log_cosh_loss(preds, labels):
33 |     loss = torch.log(torch.cosh(preds - labels))
34 |     return torch.mean(loss)
35 | 
36 | 
37 | def huber_loss(preds, labels, delta=1.0):
38 |     residual = torch.abs(preds - labels)
39 |     condition = torch.le(residual, delta)
40 |     small_res = 0.5 * torch.square(residual)
41 |     large_res = delta * residual - 0.5 * delta * delta
42 |     return torch.mean(torch.where(condition, small_res, 
large_res)) 43 | # lo = torch.nn.SmoothL1Loss() 44 | # return lo(preds, labels) 45 | 46 | 47 | def quantile_loss(preds, labels, delta=0.25): 48 | condition = torch.ge(labels, preds) 49 | large_res = delta * (labels - preds) 50 | small_res = (1 - delta) * (preds - labels) 51 | return torch.mean(torch.where(condition, large_res, small_res)) 52 | 53 | 54 | def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, min_s=1e-4): 55 | labels[torch.abs(labels) < min_s] = 0 56 | if np.isnan(null_val) and eps != 0: 57 | loss = torch.abs((preds - labels) / (labels + eps)) 58 | return torch.mean(loss) 59 | if np.isnan(null_val): 60 | mask = ~torch.isnan(labels) 61 | else: 62 | mask = labels.ne(null_val) 63 | mask = mask.float() 64 | mask /= torch.mean(mask) 65 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 66 | loss = torch.abs((preds - labels) / labels) 67 | loss = loss * mask 68 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 69 | return torch.mean(loss) 70 | 71 | 72 | def masked_mse_torch(preds, labels, null_val=np.nan, min_s=1e-4): 73 | labels[torch.abs(labels) < min_s] = 0 74 | if np.isnan(null_val): 75 | mask = ~torch.isnan(labels) 76 | else: 77 | mask = labels.ne(null_val) 78 | mask = mask.float() 79 | mask /= torch.mean(mask) 80 | mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) 81 | loss = torch.square(torch.sub(preds, labels)) 82 | loss = loss * mask 83 | loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss) 84 | return torch.mean(loss) 85 | 86 | 87 | def masked_rmse_torch(preds, labels, null_val=np.nan, min_s=1e-4): 88 | labels[torch.abs(labels) < min_s] = 0 89 | return torch.sqrt(masked_mse_torch(preds=preds, labels=labels, 90 | null_val=null_val)) 91 | 92 | 93 | def r2_score_torch(preds, labels): 94 | preds = preds.cpu().flatten() 95 | labels = labels.cpu().flatten() 96 | return r2_score(labels, preds) 97 | 98 | 99 | def explained_variance_score_torch(preds, labels): 100 | preds = preds.cpu().flatten() 101 | labels = labels.cpu().flatten() 102 | return explained_variance_score(labels, preds) 103 | 104 | 105 | def masked_rmse_np(preds, labels, null_val=np.nan): 106 | return np.sqrt(masked_mse_np(preds=preds, labels=labels, 107 | null_val=null_val)) 108 | 109 | 110 | def masked_mse_np(preds, labels, null_val=np.nan): 111 | with np.errstate(divide='ignore', invalid='ignore'): 112 | if np.isnan(null_val): 113 | mask = ~np.isnan(labels) 114 | else: 115 | mask = np.not_equal(labels, null_val) 116 | mask = mask.astype('float32') 117 | mask /= np.mean(mask) 118 | rmse = np.square(np.subtract(preds, labels)).astype('float32') 119 | rmse = np.nan_to_num(rmse * mask) 120 | return np.mean(rmse) 121 | 122 | 123 | def masked_mae_np(preds, labels, null_val=np.nan): 124 | with np.errstate(divide='ignore', invalid='ignore'): 125 | if np.isnan(null_val): 126 | mask = ~np.isnan(labels) 127 | else: 128 | mask = np.not_equal(labels, null_val) 129 | mask = mask.astype('float32') 130 | mask /= np.mean(mask) 131 | mae = np.abs(np.subtract(preds, labels)).astype('float32') 132 | mae = np.nan_to_num(mae * mask) 133 | return np.mean(mae) 134 | 135 | 136 | def masked_mape_np(preds, labels, null_val=np.nan): 137 | with np.errstate(divide='ignore', invalid='ignore'): 138 | if np.isnan(null_val): 139 | mask = ~np.isnan(labels) 140 | else: 141 | mask = np.not_equal(labels, null_val) 142 | mask = mask.astype('float32') 143 | mask /= np.mean(mask) 144 | mape = np.abs(np.divide(np.subtract( 145 | preds, labels).astype('float32'), labels)) 146 
|         mape = np.nan_to_num(mask * mape)
147 |         return np.mean(mape)
148 | 
149 | 
150 | def r2_score_np(preds, labels):
151 |     preds = preds.flatten()
152 |     labels = labels.flatten()
153 |     return r2_score(labels, preds)
154 | 
155 | 
156 | def explained_variance_score_np(preds, labels):
157 |     preds = preds.flatten()
158 |     labels = labels.flatten()
159 |     return explained_variance_score(labels, preds)
160 | 
-------------------------------------------------------------------------------- /libcity/model/traffic_flow_prediction/__init__.py: --------------------------------------------------------------------------------
1 | from libcity.model.traffic_flow_prediction.MultiATGCN import MultiATGCN
2 | 
3 | __all__ = [
4 | 
5 |     "MultiATGCN",
6 | ]
7 | 
-------------------------------------------------------------------------------- /libcity/model/utils.py: --------------------------------------------------------------------------------
1 | import scipy.sparse as sp
2 | from scipy.sparse import linalg
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | # def build_sparse_matrix(device, lap):
8 | #     lap = lap.tocoo()
9 | #     indices = np.column_stack((lap.row, lap.col))
10 | #     # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
11 | #     indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
12 | #     lap = torch.sparse_coo_tensor(indices.T, lap.data, lap.shape, device=device)
13 | #     return lap.to(torch.float32)
14 | 
15 | 
16 | def build_sparse_matrix(device, lap):
17 |     """
18 |     Build a sparse matrix (tensor).
19 | 
20 |     Args:
21 |         device:
22 |         lap: the Laplacian
23 | 
24 |     Returns:
25 | 
26 |     """
27 |     shape = lap.shape
28 |     i = torch.LongTensor(np.vstack((lap.row, lap.col)).astype(int))
29 |     v = torch.FloatTensor(lap.data)
30 |     return torch.sparse.FloatTensor(i, v, torch.Size(shape)).to(device)
31 | 
32 | 
33 | def get_cheb_polynomial(l_tilde, k):
34 |     """
35 |     compute a list of Chebyshev polynomials from T_0 to T_k
36 | 
37 |     Args:
38 |         l_tilde(scipy.sparse.coo.coo_matrix): scaled Laplacian, shape (N, N)
39 |         k(int): the maximum order of chebyshev polynomials
40 | 
41 |     Returns:
42 |         list(scipy.sparse.coo_matrix): cheb_polynomials, length k + 1, from T_0 to T_k
43 |     """
44 |     l_tilde = sp.coo_matrix(l_tilde)
45 |     num = l_tilde.shape[0]
46 |     cheb_polynomials = [sp.eye(num).tocoo(), l_tilde.copy()]
47 |     for i in range(2, k + 1):
48 |         cheb_i = (2 * l_tilde).dot(cheb_polynomials[i - 1]) - cheb_polynomials[i - 2]
49 |         cheb_polynomials.append(cheb_i.tocoo())
50 |     return cheb_polynomials
51 | 
52 | 
53 | def get_supports_matrix(adj_mx, filter_type='laplacian', undirected=True):
54 |     """
55 |     Select among the different types of Laplacians.
56 | 
57 |     Args:
58 |         undirected:
59 |         adj_mx:
60 |         filter_type:
61 | 
62 |     Returns:
63 | 
64 |     """
65 |     supports = []
66 |     if filter_type == "laplacian":
67 |         supports.append(calculate_scaled_laplacian(adj_mx, lambda_max=None, undirected=undirected))
68 |     elif filter_type == "random_walk":
69 |         supports.append(calculate_random_walk_matrix(adj_mx).T)
70 |     elif filter_type == "dual_random_walk":
71 |         supports.append(calculate_random_walk_matrix(adj_mx).T)
72 |         supports.append(calculate_random_walk_matrix(adj_mx.T).T)
73 |     else:
74 |         supports.append(calculate_scaled_laplacian(adj_mx))
75 |     return supports
76 | 
77 | 
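# ---------------------------------------------------------------------------
# Editor's toy example (hedged): Chebyshev recursion T_k = 2*L@T_{k-1} - T_{k-2}
# on a 2-node graph, where the toy scaled Laplacian is a plain permutation:
#
#   >>> import numpy as np
#   >>> L = np.array([[0., 1.], [1., 0.]])
#   >>> cheb = get_cheb_polynomial(L, 2)   # [T0, T1, T2]
#   >>> cheb[2].toarray()                  # T2 = 2*L@L - I = I
#   array([[1., 0.],
#          [0., 1.]])
# ---------------------------------------------------------------------------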
78 | def calculate_normalized_laplacian(adj):
79 |     """
80 |     L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2
81 |     The symmetrically normalized Laplacian.
82 | 
83 |     Args:
84 |         adj: adj matrix
85 | 
86 |     Returns:
87 |         np.ndarray: L
88 |     """
89 |     adj = sp.coo_matrix(adj)
90 |     d = np.array(adj.sum(1))
91 |     d_inv_sqrt = np.power(d, -0.5).flatten()
92 |     d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
93 |     d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
94 |     normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
95 |     return normalized_laplacian
96 | 
97 | 
98 | def calculate_random_walk_matrix(adj_mx):
99 |     """
100 |     L = D^-1 * A
101 |     The random-walk Laplacian.
102 | 
103 |     Args:
104 |         adj_mx: adj matrix
105 | 
106 |     Returns:
107 |         np.ndarray: L
108 |     """
109 |     adj_mx = sp.coo_matrix(adj_mx)
110 |     d = np.array(adj_mx.sum(1))
111 |     d_inv = np.power(d, -1).flatten()
112 |     d_inv[np.isinf(d_inv)] = 0.
113 |     d_mat_inv = sp.diags(d_inv)
114 |     random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
115 |     return random_walk_mx
116 | 
117 | 
118 | def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
119 |     """
120 |     Compute the rescaled (approximated) Laplacian ~L.
121 | 
122 |     Args:
123 |         adj_mx:
124 |         lambda_max:
125 |         undirected:
126 | 
127 |     Returns:
128 |         ~L = 2 * L / lambda_max - I
129 |     """
130 |     adj_mx = sp.coo_matrix(adj_mx)
131 |     if undirected:
132 |         bigger = adj_mx > adj_mx.T
133 |         smaller = adj_mx < adj_mx.T
134 |         notequall = adj_mx != adj_mx.T
135 |         adj_mx = adj_mx - adj_mx.multiply(notequall) + adj_mx.multiply(bigger) + adj_mx.T.multiply(smaller)  # symmetrize by keeping the larger of A[i, j] and A[j, i]
136 |     lap = calculate_normalized_laplacian(adj_mx)
137 |     if lambda_max is None:
138 |         lambda_max, _ = linalg.eigsh(lap, 1, which='LM')
139 |         lambda_max = lambda_max[0]
140 |     lap = sp.csr_matrix(lap)
141 |     m, _ = lap.shape
142 |     identity = sp.identity(m, format='csr', dtype=lap.dtype)
143 |     lap = (2 / lambda_max * lap) - identity
144 |     return lap.astype(np.float32).tocoo()
145 | 
-------------------------------------------------------------------------------- /libcity/pipeline/__init__.py: --------------------------------------------------------------------------------
1 | from libcity.pipeline.pipeline import run_model, hyper_parameter, objective_function
2 | 
3 | __all__ = [
4 |     "run_model",
5 |     "hyper_parameter",
6 |     "objective_function"
7 | ]
8 | 
-------------------------------------------------------------------------------- /libcity/temp/1.4-data_prepare_plot_POI.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import os
5 | import datetime
6 | import glob
7 | import geopandas as gpd
8 | import seaborn as sns
9 | import matplotlib.dates as mdates
10 | import matplotlib as mpl
11 | 
12 | pd.options.mode.chained_assignment = None
13 | results_path = 'D:\\ST_Graph\\Data\\'  # fixed: the r-prefix plus doubled backslashes produced literal double backslashes in the path
14 | 
15 | plt.rcParams.update(
16 |     {'font.size': 13, 'font.family': "serif", 'mathtext.fontset': 'dejavuserif', 'xtick.direction': 'in',
17 |      'xtick.major.size': 0.5, 'grid.linestyle': "--", 'axes.grid': True, "grid.alpha": 1, "grid.color": "#cccccc",
18 |      'xtick.minor.size': 1.5, 'xtick.minor.width': 0.5, 'xtick.minor.visible': True, 'xtick.top': True,
19 |      'ytick.direction': 'in', 'ytick.major.size': 0.5, 'ytick.minor.size': 1.5, 'ytick.minor.width': 0.5,
20 |      'ytick.minor.visible': True, 'ytick.right': True, 'axes.linewidth': 0.5, 'grid.linewidth': 0.5,
21 |      'lines.linewidth': 1.5, 'legend.frameon': False, 'savefig.bbox': 'tight', 'savefig.pad_inches': 0.05})
22 | 
23 | # Get county subdivision
24 | f_na, f_nas = 'SG_CTS_Hourly', 'SG_CTS_Hourly_Single'
25 | POI_Type = ['Education', 'Others', 'Recreation', 'Residential', 'Restaurant', 'Retail', 'Service']
26 | 
27 | # Dynamic
28 | Dyna = pd.read_csv(results_path + r'Lib_Data\old\%s\%s.dyna' % (f_na, f_na))
29 | Dyna['time'] = pd.to_datetime(Dyna['time'])
30 | 
Dyna['date'] = Dyna['time'].dt.date 31 | Dyna['hour'] = Dyna['time'].dt.hour 32 | Dyna['dayofweek'] = Dyna['time'].dt.dayofweek 33 | 34 | # Plot time series 35 | # mpl.rcParams['axes.prop_cycle'] = plt.cycler("color", plt.cm.coolwarm(np.linspace(0, 1, 7))) 36 | mpl.rcParams['axes.prop_cycle'] = plt.cycler("color", plt.cm.Set2.colors) 37 | fig, ax = plt.subplots(figsize=(10, 6)) 38 | Dyna.groupby(['date'])[POI_Type].sum().plot(ax=ax, lw=1.5) 39 | ax.set_xlabel('Date') 40 | ax.set_ylabel('Population Inflow') 41 | plt.legend(ncol=3) 42 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) 43 | ax.xaxis.set_major_locator(mdates.MonthLocator(interval=2)) 44 | ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) 45 | plt.tight_layout() 46 | plt.savefig(r'D:\ST_Graph\Figures\Daily_pattern.png', dpi=1000) 47 | plt.close() 48 | 49 | fig, ax = plt.subplots(figsize=(10, 6)) 50 | Dyna.groupby(['dayofweek', 'hour'])[POI_Type].sum().plot(ax=ax, lw=1.5) 51 | ax.set_xlabel('Dayofweek, Hour') 52 | ax.set_ylabel('Population Inflow') 53 | plt.legend(ncol=3) 54 | ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) 55 | plt.tight_layout() 56 | plt.savefig(r'D:\ST_Graph\Figures\Hourly_pattern.png', dpi=1000) 57 | plt.close() 58 | 59 | # Plot time series for each CTS 60 | Dyna['All'] = Dyna[POI_Type].sum(axis=1) 61 | Dyna = Dyna.rename({'entity_id': 'CTSFIPS'}, axis=1) 62 | Dyna['CTSFIPS'] = Dyna['CTSFIPS'].astype(str) 63 | tempfile = Dyna.groupby(['CTSFIPS', 'date'])['All'].sum().reset_index() 64 | fig, ax = plt.subplots(figsize=(10, 6)) 65 | for kk in set(tempfile['CTSFIPS']): 66 | ax.plot(tempfile.loc[tempfile['CTSFIPS'] == kk, 'date'], tempfile.loc[tempfile['CTSFIPS'] == kk, 'All']) 67 | ax.set_xlabel('Date') 68 | ax.set_ylabel('Population Inflow') 69 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) 70 | ax.xaxis.set_major_locator(mdates.MonthLocator(interval=2)) 71 | ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0), useMathText=True) 72 | plt.tight_layout() 73 | plt.savefig(r'D:\ST_Graph\Figures\Daily_pattern_CTS.png', dpi=1000) 74 | plt.close() 75 | 76 | # Geo: PA and OD 77 | # Get PA for each CTS 78 | CTS_Info = pd.read_pickle(r'D:\ST_Graph\Results\CTS_Info.pkl') 79 | CTS_Info['x'] = CTS_Info.centroid.x 80 | CTS_Info['y'] = CTS_Info.centroid.y 81 | data_raw_avg = Dyna.groupby(['CTSFIPS'])[['All'] + POI_Type].mean().reset_index() 82 | poly = CTS_Info.merge(data_raw_avg, on='CTSFIPS') 83 | 84 | # Get OD 85 | CTS_OD = pd.read_csv(results_path + r'Lib_Data\%s\%s.rel' % (f_na, f_na)) 86 | CTS_OD['origin_id'] = CTS_OD['origin_id'].astype(str) 87 | CTS_OD['destination_id'] = CTS_OD['destination_id'].astype(str) 88 | bmc_zone = CTS_Info[['CTSFIPS', 'x', 'y']] 89 | bmc_zone.columns = ['origin_id', 'O_Lng', 'O_Lat'] 90 | CTS_OD = CTS_OD.merge(bmc_zone, on='origin_id') 91 | bmc_zone.columns = ['destination_id', 'D_Lng', 'D_Lat'] 92 | CTS_OD = CTS_OD.merge(bmc_zone, on='destination_id') 93 | 94 | # Calculate for other POI types 95 | CTS_OD_All = pd.concat(map(pd.read_pickle, glob.glob(os.path.join(r'D:\ST_Graph\Data\SG_PC', 'CTS_OD_Weekly_*.pkl')))) 96 | CTS_OD_All.rename({'CTSFIPS_O': 'origin_id', 'CTSFIPS_D': 'destination_id'}, axis=1, inplace=True) 97 | t_s = datetime.datetime(2019, 1, 1) 98 | t_e = datetime.datetime(2020, 3, 1) 99 | CTS_OD_All['Time'] = pd.to_datetime(CTS_OD_All['Time']) 100 | CTS_OD_All = CTS_OD_All[(CTS_OD_All['Time'] < t_e) & (CTS_OD_All['Time'] >= t_s)].reset_index(drop=True) 101 | CTS_OD_All['All'] = 
CTS_OD_All[POI_Type].sum(axis=1) 102 | 103 | for kk in POI_Type + ['All']: 104 | CTS_OD_s = CTS_OD_All[['origin_id', 'destination_id'] + [kk]] 105 | CTS_OD_s = CTS_OD_s.groupby(['origin_id', 'destination_id']).sum().reset_index() 106 | CTS_D = CTS_OD_s.groupby(['destination_id'])[kk].sum().reset_index() 107 | CTS_D.columns = ['destination_id', 'Inflow'] 108 | CTS_OD_s = CTS_OD_s.merge(CTS_D, on='destination_id') 109 | CTS_OD_s[kk] = CTS_OD_s[kk] / CTS_OD_s['Inflow'] 110 | CTS_OD_s.drop(['Inflow'], axis=1, inplace=True) 111 | CTS_OD = CTS_OD.merge(CTS_OD_s, on=['origin_id', 'destination_id'], how='left') 112 | CTS_OD = CTS_OD[CTS_OD['origin_id'] != CTS_OD['destination_id']].reset_index(drop=True) 113 | CTS_OD = CTS_OD.fillna(0) 114 | 115 | for p1 in POI_Type + ['All']: 116 | fig, ax = plt.subplots(figsize=(9, 7)) 117 | poly.geometry.boundary.plot(color=None, edgecolor='k', linewidth=1, ax=ax) 118 | # Plot PA 119 | poly.plot(column=p1, ax=ax, legend=True, scheme='UserDefined', cmap='coolwarm', linewidth=0, edgecolor='white', 120 | classification_kwds=dict(bins=[np.quantile(poly[p1], 1 / 6), np.quantile(poly[p1], 2 / 6), 121 | np.quantile(poly[p1], 3 / 6), np.quantile(poly[p1], 4 / 6), 122 | np.quantile(poly[p1], 5 / 6)]), 123 | legend_kwds=dict(frameon=False, ncol=3)) 124 | ax.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False) 125 | ax.axis('off') 126 | ax.set_title('Average Hourly Population Inflow (%s)' % p1, pad=-0) 127 | 128 | # Reset Legend 129 | patch_col = ax.get_legend() 130 | patch_col.set_bbox_to_anchor((1, 0.05)) 131 | legend_labels = ax.get_legend().get_texts() 132 | for bound, legend_label in \ 133 | zip(['< ' + str(round(np.quantile(poly[p1], 1 / 6), 2)), 134 | str(round(np.quantile(poly[p1], 1 / 6), 2)) + ' - ' + str(round(np.quantile(poly[p1], 2 / 6), 2)), 135 | str(round(np.quantile(poly[p1], 2 / 6), 2)) + ' - ' + str(round(np.quantile(poly[p1], 3 / 6), 2)), 136 | str(round(np.quantile(poly[p1], 3 / 6), 2)) + ' - ' + str(round(np.quantile(poly[p1], 4 / 6), 2)), 137 | str(round(np.quantile(poly[p1], 4 / 6), 2)) + ' - ' + str(round(np.quantile(poly[p1], 5 / 6), 2)), 138 | '> ' + str(round(np.quantile(poly[p1], 5 / 6), 2))], legend_labels): 139 | legend_label.set_text(bound) 140 | plt.subplots_adjust(top=0.938, bottom=0.137, left=0.016, right=0.984, hspace=0.2, wspace=0.11) 141 | 142 | # Plot OD 143 | Cn = CTS_OD[CTS_OD[p1] > 0.001].reset_index(drop=True) 144 | for kk in range(0, len(Cn)): 145 | ax.annotate('', xy=(Cn.loc[kk, 'O_Lng'], Cn.loc[kk, 'O_Lat']), 146 | xytext=(Cn.loc[kk, 'D_Lng'], Cn.loc[kk, 'D_Lat']), 147 | arrowprops={'arrowstyle': '-', 'lw': Cn.loc[kk, p1] * 10, 'color': 'k', 'alpha': 0.8, 148 | 'connectionstyle': "arc3,rad=0.2"}, va='center') 149 | 150 | plt.savefig(r'D:\ST_Graph\Figures\PA_OD_%s.png' % p1, dpi=600) 151 | plt.close() 152 | -------------------------------------------------------------------------------- /libcity/temp/MultiATGCN-3TU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from logging import getLogger 5 | from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel 6 | from libcity.model import loss 7 | 8 | 9 | class FusionLayer(nn.Module): 10 | # Matrix-based fusion 11 | def __init__(self, out_window, n, out_dim): 12 | super(FusionLayer, self).__init__() 13 | # define the trainable parameter 14 | self.weights = nn.Parameter(torch.FloatTensor(1, out_window, n, out_dim)) 15 | 16 | 
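# ---------------------------------------------------------------------------
# Editor's note (hedged): the fusion weight defined above has shape
# (1, out_window, n, out_dim), so multiplying an input of shape
# (B, out_window, n, out_dim) broadcasts over the batch dimension -- every
# sample is rescaled by the same learned per-(step, node, channel) weight.
# ---------------------------------------------------------------------------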
def forward(self, x): 17 | # assuming x is of size B-n-h-w 18 | x = x * self.weights # element-wise multiplication 19 | return x 20 | 21 | 22 | class AVWGCN(nn.Module): 23 | def __init__(self, dim_in, dim_out, cheb_k, embed_dim): 24 | super(AVWGCN, self).__init__() 25 | self.cheb_k = cheb_k 26 | self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) 27 | self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) 28 | 29 | def forward(self, x, node_embeddings): 30 | # x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] 31 | # output shape [B, N, C] 32 | node_num = node_embeddings.shape[0] # node_embeddings: E 33 | supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1) # A~ 34 | support_set = [torch.eye(node_num).to(supports.device), supports] 35 | # default cheb_k = 3 36 | for k in range(2, self.cheb_k): 37 | support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2]) 38 | supports = torch.stack(support_set, dim=0) 39 | weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool) # N, cheb_k, dim_in, dim_out 40 | bias = torch.matmul(node_embeddings, self.bias_pool) # N, dim_out 41 | x_g = torch.einsum("knm,bmc->bknc", supports, x) # B, cheb_k, N, dim_in 42 | x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in 43 | x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # b, N, dim_out 44 | return x_gconv 45 | 46 | 47 | class AGCRNCell(nn.Module): 48 | def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): 49 | super(AGCRNCell, self).__init__() 50 | self.node_num = node_num 51 | self.hidden_dim = dim_out 52 | self.gate = AVWGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) 53 | self.update = AVWGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim) 54 | 55 | def forward(self, x, state, node_embeddings): 56 | # x: B, num_nodes, input_dim 57 | # state: B, num_nodes, hidden_dim 58 | state = state.to(x.device) 59 | input_and_state = torch.cat((x, state), dim=-1) 60 | z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) 61 | z, r = torch.split(z_r, self.hidden_dim, dim=-1) 62 | candidate = torch.cat((x, z * state), dim=-1) 63 | hc = torch.tanh(self.update(candidate, node_embeddings)) 64 | h = r * state + (1 - r) * hc 65 | return h 66 | 67 | def init_hidden_state(self, batch_size): 68 | return torch.zeros(batch_size, self.node_num, self.hidden_dim) 69 | 70 | 71 | class AVWDCRNN(nn.Module): 72 | def __init__(self, config): 73 | super(AVWDCRNN, self).__init__() 74 | self.num_nodes = config['num_nodes'] 75 | self.feature_dim = config['feature_dim'] 76 | self.hidden_dim = config.get('rnn_units', 64) 77 | self.embed_dim = config.get('embed_dim', 10) 78 | self.num_layers = config.get('num_layers', 2) 79 | self.cheb_k = config.get('cheb_order', 2) 80 | self.input_window = config.get('input_window', 1) 81 | self.output_window = config.get('output_window', 1) 82 | self.output_dim = config.get('output_dim', 1) 83 | assert self.num_layers >= 1, 'At least one DCRNN layer in the Encoder.' 
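# ---------------------------------------------------------------------------
# Editor's note (hedged): AVWGCN above builds the graph on the fly from the
# learned node embeddings E (shape N x D): A~ = softmax(relu(E @ E.T)), then
# stacks Chebyshev-style supports [I, A~, 2*A~@T_{k-1} - T_{k-2}, ...] and
# draws per-node filter weights from the shared pools, e.g.
#   weights = torch.einsum('nd,dkio->nkio', E, weights_pool)  # (N, K, C_in, C_out)
# ---------------------------------------------------------------------------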
84 | 85 | self.dcrnn_cells = nn.ModuleList() 86 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.feature_dim, 87 | self.hidden_dim, self.cheb_k, self.embed_dim)) 88 | for _ in range(1, self.num_layers): 89 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.hidden_dim, 90 | self.hidden_dim, self.cheb_k, self.embed_dim)) 91 | 92 | self.ln = nn.LayerNorm(self.hidden_dim) 93 | # self.end_conv = nn.Conv2d(1, self.output_window, kernel_size=(1, self.hidden_dim), bias=True) 94 | self.end_conv = nn.Conv2d(self.input_window, self.output_window, kernel_size=(1, self.hidden_dim), bias=True) 95 | self.fusionlayer = FusionLayer(self.output_window, self.num_nodes, self.output_dim) 96 | 97 | def forward(self, x, init_state, node_embeddings): 98 | # shape of x: (B, T, N, D) 99 | # shape of init_state: (num_layers, B, N, hidden_dim) 100 | assert x.shape[2] == self.num_nodes 101 | seq_length = x.shape[1] 102 | current_inputs = x 103 | output_hidden = [] 104 | for i in range(self.num_layers): 105 | state = init_state[i] 106 | inner_states = [] 107 | for t in range(seq_length): 108 | state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings) # B, N, hidden_dim 109 | inner_states.append(state) 110 | output_hidden.append(state) 111 | current_inputs = torch.stack(inner_states, dim=1) 112 | # current_inputs: the outputs of last layer: (B, T, N, hidden_dim) 113 | # output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) 114 | # last_state: (B, N, hidden_dim) 115 | # output_hidden = self.ln(output_hidden[-1]) 116 | output_hidden = self.end_conv(current_inputs) # B, T*C, N, 1 117 | # output_hidden = self.end_conv(output_hidden[-1].unsqueeze(dim=1)) 118 | output_hidden = self.fusionlayer(output_hidden) 119 | return output_hidden 120 | # return output_hidden[-1].unsqueeze(dim=1) 121 | 122 | def init_hidden(self, batch_size): 123 | init_states = [] 124 | for i in range(self.num_layers): 125 | init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size)) 126 | return torch.stack(init_states, dim=0) # (num_layers, B, N, hidden_dim) 127 | 128 | 129 | class MultiATGCN(AbstractTrafficStateModel): 130 | def __init__(self, config, data_feature): 131 | self.num_nodes = data_feature.get('num_nodes', 1) 132 | self.feature_dim = data_feature.get('feature_dim', 1) 133 | config['num_nodes'] = self.num_nodes 134 | config['feature_dim'] = self.feature_dim 135 | 136 | super().__init__(config, data_feature) 137 | self.input_window = config.get('input_window', 1) 138 | self.output_window = config.get('output_window', 1) 139 | self.output_dim = self.data_feature.get('output_dim', 1) 140 | self.hidden_dim = config.get('rnn_units', 64) 141 | self.embed_dim = config.get('embed_dim', 10) 142 | self.len_period = self.data_feature.get('len_period', 0) 143 | self.len_trend = self.data_feature.get('len_trend', 0) 144 | self.len_closeness = self.data_feature.get('len_closeness', 0) 145 | 146 | self.nembed_close = nn.Parameter(torch.randn(self.num_nodes, self.embed_dim), requires_grad=True) 147 | self.encoder_close = AVWDCRNN(config) 148 | self.nembed_period = nn.Parameter(torch.randn(self.num_nodes, self.embed_dim), requires_grad=True) 149 | self.encoder_period = AVWDCRNN(config) 150 | self.nembed_trend = nn.Parameter(torch.randn(self.num_nodes, self.embed_dim), requires_grad=True) 151 | self.encoder_trend = AVWDCRNN(config) 152 | # self.ln = nn.LayerNorm(self.hidden_dim) 153 | 154 | self.device = config.get('device', torch.device('cpu')) 155 | self._logger = getLogger() 156 | 
self._scaler = self.data_feature.get('scaler') 157 | self._init_parameters() 158 | 159 | def _init_parameters(self): 160 | for p in self.parameters(): 161 | if p.dim() > 1: 162 | nn.init.xavier_uniform_(p) 163 | else: 164 | nn.init.uniform_(p) 165 | 166 | def forward(self, batch): 167 | # source: B, T_1, N, F_D 168 | # target: B, T_2, N, F_D 169 | source = batch['X'] 170 | output = 0 171 | init_state = self.encoder_close.init_hidden(source.shape[0]) 172 | if self.len_closeness > 0: 173 | begin_index = 0 174 | end_index = begin_index + self.len_closeness 175 | output_hours = self.encoder_close(source[:, begin_index:end_index, :, :], init_state, self.nembed_close) 176 | output += output_hours 177 | if self.len_period > 0: 178 | begin_index = self.len_closeness 179 | end_index = begin_index + self.len_period 180 | output_days = self.encoder_period(source[:, begin_index:end_index, :, :], init_state, self.nembed_period) 181 | output += output_days 182 | if self.len_trend > 0: 183 | begin_index = self.len_closeness + self.len_period 184 | end_index = begin_index + self.len_trend 185 | output_weeks = self.encoder_trend(source[:, begin_index:end_index, :, :], init_state, self.nembed_trend) 186 | output += output_weeks 187 | return output 188 | 189 | def calculate_loss(self, batch): 190 | y_true = batch['y'] 191 | y_predicted = self.predict(batch) 192 | y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim]) 193 | y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim]) 194 | # y_true = y_true.sum(-1, keepdims=True) 195 | return loss.masked_mae_torch(y_predicted, y_true, 0) 196 | 197 | def predict(self, batch): 198 | return self.forward(batch) 199 | -------------------------------------------------------------------------------- /libcity/temp/MultiATGCN-3TUSimple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from logging import getLogger 5 | from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel 6 | from libcity.model import loss 7 | 8 | 9 | class AVWGCN(nn.Module): 10 | def __init__(self, dim_in, dim_out, cheb_k, embed_dim): 11 | super(AVWGCN, self).__init__() 12 | self.cheb_k = cheb_k 13 | self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) 14 | self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) 15 | 16 | def forward(self, x, node_embeddings): 17 | # x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] 18 | # output shape [B, N, C] 19 | node_num = node_embeddings.shape[0] # node_embeddings: E 20 | supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1) # A~ 21 | support_set = [torch.eye(node_num).to(supports.device), supports] 22 | # default cheb_k = 3 23 | for k in range(2, self.cheb_k): 24 | support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2]) 25 | supports = torch.stack(support_set, dim=0) 26 | weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool) # N, cheb_k, dim_in, dim_out 27 | bias = torch.matmul(node_embeddings, self.bias_pool) # N, dim_out 28 | x_g = torch.einsum("knm,bmc->bknc", supports, x) # B, cheb_k, N, dim_in 29 | x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in 30 | x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # b, N, dim_out 31 | return x_gconv 32 | 33 | 34 | class AGCRNCell(nn.Module): 35 | def __init__(self, 
node_num, dim_in, dim_out, cheb_k, embed_dim): 36 | super(AGCRNCell, self).__init__() 37 | self.node_num = node_num 38 | self.hidden_dim = dim_out 39 | self.gate = AVWGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) 40 | self.update = AVWGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim) 41 | 42 | def forward(self, x, state, node_embeddings): 43 | # x: B, num_nodes, input_dim 44 | # state: B, num_nodes, hidden_dim 45 | state = state.to(x.device) 46 | input_and_state = torch.cat((x, state), dim=-1) 47 | z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) 48 | z, r = torch.split(z_r, self.hidden_dim, dim=-1) 49 | candidate = torch.cat((x, z * state), dim=-1) 50 | hc = torch.tanh(self.update(candidate, node_embeddings)) 51 | h = r * state + (1 - r) * hc 52 | return h 53 | 54 | def init_hidden_state(self, batch_size): 55 | return torch.zeros(batch_size, self.node_num, self.hidden_dim) 56 | 57 | 58 | class AVWDCRNN(nn.Module): 59 | def __init__(self, config): 60 | super(AVWDCRNN, self).__init__() 61 | self.num_nodes = config['num_nodes'] 62 | self.feature_dim = config['feature_dim'] 63 | self.hidden_dim = config.get('rnn_units', 64) 64 | self.embed_dim = config.get('embed_dim', 10) 65 | self.num_layers = config.get('num_layers', 2) 66 | self.cheb_k = config.get('cheb_order', 2) 67 | assert self.num_layers >= 1, 'At least one DCRNN layer in the Encoder.' 68 | 69 | self.dcrnn_cells = nn.ModuleList() 70 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.feature_dim, 71 | self.hidden_dim, self.cheb_k, self.embed_dim)) 72 | for _ in range(1, self.num_layers): 73 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.hidden_dim, 74 | self.hidden_dim, self.cheb_k, self.embed_dim)) 75 | 76 | def forward(self, x, init_state, node_embeddings): 77 | # shape of x: (B, T, N, D) 78 | # shape of init_state: (num_layers, B, N, hidden_dim) 79 | assert x.shape[2] == self.num_nodes and x.shape[3] == self.feature_dim 80 | seq_length = x.shape[1] 81 | current_inputs = x 82 | output_hidden = [] 83 | for i in range(self.num_layers): 84 | state = init_state[i] 85 | inner_states = [] 86 | for t in range(seq_length): 87 | state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings) 88 | inner_states.append(state) 89 | output_hidden.append(state) 90 | current_inputs = torch.stack(inner_states, dim=1) 91 | # current_inputs: the outputs of last layer: (B, T, N, hidden_dim) 92 | # output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) 93 | # last_state: (B, N, hidden_dim) 94 | return current_inputs, output_hidden 95 | 96 | def init_hidden(self, batch_size): 97 | init_states = [] 98 | for i in range(self.num_layers): 99 | init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size)) 100 | return torch.stack(init_states, dim=0) # (num_layers, B, N, hidden_dim) 101 | 102 | 103 | class MultiATGCN(AbstractTrafficStateModel): 104 | def __init__(self, config, data_feature): 105 | self.num_nodes = data_feature.get('num_nodes', 1) 106 | self.feature_dim = data_feature.get('feature_dim', 1) 107 | config['num_nodes'] = self.num_nodes 108 | config['feature_dim'] = self.feature_dim 109 | 110 | super().__init__(config, data_feature) 111 | self.input_window = config.get('input_window', 1) 112 | self.output_window = config.get('output_window', 1) 113 | self.output_dim = self.data_feature.get('output_dim', 1) 114 | self.hidden_dim = config.get('rnn_units', 64) 115 | self.embed_dim = config.get('embed_dim', 10) 116 | 117 | self.len_period = 
self.data_feature.get('len_period', 0) 118 | self.len_trend = self.data_feature.get('len_trend', 0) 119 | self.len_closeness = self.data_feature.get('len_closeness', 0) 120 | 121 | self.weight_t1 = nn.Parameter(torch.FloatTensor(1, self.len_closeness, 1, self.feature_dim)) 122 | self.weight_t2 = nn.Parameter(torch.FloatTensor(1, self.len_period, 1, self.feature_dim)) 123 | self.weight_t3 = nn.Parameter(torch.FloatTensor(1, self.len_trend, 1, self.feature_dim)) 124 | 125 | self.node_embeddings = nn.Parameter(torch.randn(self.num_nodes, self.embed_dim), requires_grad=True) 126 | self.encoder = AVWDCRNN(config) 127 | self.end_conv = nn.Conv2d(self.input_window, self.output_window * self.output_dim, 128 | kernel_size=(1, self.hidden_dim), bias=True) 129 | 130 | self.device = config.get('device', torch.device('cpu')) 131 | self._logger = getLogger() 132 | self._scaler = self.data_feature.get('scaler') 133 | self._init_parameters() 134 | 135 | def _init_parameters(self): 136 | for p in self.parameters(): 137 | if p.dim() > 1: 138 | nn.init.xavier_uniform_(p) 139 | else: 140 | nn.init.uniform_(p) 141 | 142 | def forward(self, batch): 143 | # source: B, T_1, N, D 144 | # target: B, T_2, N, D 145 | # supports = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec1.transpose(0,1))), dim=1) 146 | source = batch['X'] 147 | output = 0.0 148 | # Merge the three temporal units (closeness, period, trend) with learnable weights 149 | if self.len_closeness > 0: 150 | begin_index = 0 151 | end_index = begin_index + self.len_closeness 152 | output_hours = source[:, begin_index:end_index, :, :] 153 | output += output_hours * self.weight_t1 154 | if self.len_period > 0: 155 | begin_index = self.len_closeness 156 | end_index = begin_index + self.len_period 157 | output_days = source[:, begin_index:end_index, :, :] 158 | output += output_days * self.weight_t2 159 | if self.len_trend > 0: 160 | begin_index = self.len_closeness + self.len_period 161 | end_index = begin_index + self.len_trend 162 | output_weeks = source[:, begin_index:end_index, :, :] 163 | output += output_weeks * self.weight_t3 164 | 165 | init_state = self.encoder.init_hidden(source.shape[0]) 166 | output, _ = self.encoder(output, init_state, self.node_embeddings) # B, T, N, hidden 167 | # output = output[:, -1:, :, :] # B, 1, N, hidden 168 | 169 | # CNN-based predictor 170 | output = self.end_conv(output) # B, T*C, N, 1 171 | output = output.squeeze(-1).reshape(-1, self.output_window, self.output_dim, self.num_nodes) 172 | output = output.permute(0, 1, 3, 2) # B, T, N, C 173 | return output 174 | 175 | def calculate_loss(self, batch): 176 | y_true = batch['y'] 177 | y_predicted = self.predict(batch) 178 | y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim]) 179 | y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim]) 180 | # losses = 0.0 181 | # for i in range(7): 182 | # losses += loss.masked_mae_torch(y_predicted[:, :, :, i], y_true[:, :, :, i], 0) 183 | # return losses 184 | return loss.masked_mae_torch(y_predicted, y_true, 0) 185 | 186 | def predict(self, batch): 187 | return self.forward(batch) 188 | -------------------------------------------------------------------------------- /libcity/temp/MultiATGCN-Traffic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from logging import getLogger 5 | from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel 6 | from libcity.model import loss 7 | 8 | 9 | class 
AVWGCN(nn.Module): 10 | def __init__(self, dim_in, dim_out, cheb_k, embed_dim): 11 | super(AVWGCN, self).__init__() 12 | self.cheb_k = cheb_k 13 | self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) 14 | self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) 15 | 16 | def forward(self, x, node_embeddings): 17 | # x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] 18 | # output shape [B, N, C] 19 | node_num = node_embeddings.shape[0] # node_embeddings: E 20 | supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1) # A~ 21 | support_set = [torch.eye(node_num).to(supports.device), supports] 22 | # default cheb_k = 3 23 | for k in range(2, self.cheb_k): 24 | support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2]) 25 | supports = torch.stack(support_set, dim=0) 26 | weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool) # N, cheb_k, dim_in, dim_out 27 | bias = torch.matmul(node_embeddings, self.bias_pool) # N, dim_out 28 | x_g = torch.einsum("knm,bmc->bknc", supports, x) # B, cheb_k, N, dim_in 29 | x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in 30 | x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # b, N, dim_out 31 | return x_gconv 32 | 33 | 34 | class AGCRNCell(nn.Module): 35 | def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): 36 | super(AGCRNCell, self).__init__() 37 | self.node_num = node_num 38 | self.hidden_dim = dim_out 39 | self.gate = AVWGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) 40 | self.update = AVWGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim) 41 | 42 | def forward(self, x, state, node_embeddings): 43 | # x: B, num_nodes, input_dim 44 | # state: B, num_nodes, hidden_dim 45 | state = state.to(x.device) 46 | input_and_state = torch.cat((x, state), dim=-1) 47 | z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) 48 | z, r = torch.split(z_r, self.hidden_dim, dim=-1) 49 | candidate = torch.cat((x, z * state), dim=-1) 50 | hc = torch.tanh(self.update(candidate, node_embeddings)) 51 | h = r * state + (1 - r) * hc 52 | return h 53 | 54 | def init_hidden_state(self, batch_size): 55 | return torch.zeros(batch_size, self.node_num, self.hidden_dim) 56 | 57 | 58 | class AVWDCRNN(nn.Module): 59 | def __init__(self, config): 60 | super(AVWDCRNN, self).__init__() 61 | self.num_nodes = config['num_nodes'] 62 | self.feature_dim = config['feature_dim'] 63 | self.hidden_dim = config.get('rnn_units', 64) 64 | self.embed_dim = config.get('embed_dim', 10) 65 | self.num_layers = config.get('num_layers', 2) 66 | self.cheb_k = config.get('cheb_order', 2) 67 | assert self.num_layers >= 1, 'At least one DCRNN layer in the Encoder.' 
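# NOTE: the cells stacked below all share the adaptive graph learned from the node
# embeddings. A minimal sketch of the adjacency construction used in AVWGCN.forward
# above (E is a hypothetical stand-in for the model's node_embeddings of shape [N, D]):
#     E = torch.randn(num_nodes, embed_dim)
#     A = F.softmax(F.relu(E @ E.T), dim=1)                   # data-adaptive adjacency, rows sum to 1
#     supports = [torch.eye(num_nodes), A]                    # Chebyshev terms T0, T1
#     for k in range(2, cheb_k):                              # Tk = 2*A*T(k-1) - T(k-2)
#         supports.append(2 * A @ supports[-1] - supports[-2])
# Each AGCRNCell then applies this graph convolution inside its GRU-style gates.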
68 | 69 | self.dcrnn_cells = nn.ModuleList() 70 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.feature_dim, 71 | self.hidden_dim, self.cheb_k, self.embed_dim)) 72 | for _ in range(1, self.num_layers): 73 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.hidden_dim, 74 | self.hidden_dim, self.cheb_k, self.embed_dim)) 75 | 76 | def forward(self, x, init_state, node_embeddings): 77 | # shape of x: (B, T, N, D) 78 | # shape of init_state: (num_layers, B, N, hidden_dim) 79 | assert x.shape[2] == self.num_nodes and x.shape[3] == self.feature_dim 80 | seq_length = x.shape[1] 81 | current_inputs = x 82 | output_hidden = [] 83 | for i in range(self.num_layers): 84 | state = init_state[i] 85 | inner_states = [] 86 | for t in range(seq_length): 87 | state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings) 88 | inner_states.append(state) 89 | output_hidden.append(state) 90 | current_inputs = torch.stack(inner_states, dim=1) 91 | # current_inputs: the outputs of last layer: (B, T, N, hidden_dim) 92 | # output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) 93 | # last_state: (B, N, hidden_dim) 94 | return current_inputs, output_hidden 95 | 96 | def init_hidden(self, batch_size): 97 | init_states = [] 98 | for i in range(self.num_layers): 99 | init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size)) 100 | return torch.stack(init_states, dim=0) # (num_layers, B, N, hidden_dim) 101 | 102 | 103 | class MultiATGCN(AbstractTrafficStateModel): 104 | def __init__(self, config, data_feature): 105 | self.num_nodes = data_feature.get('num_nodes', 1) 106 | self.feature_dim = data_feature.get('feature_dim', 1) 107 | config['num_nodes'] = self.num_nodes 108 | config['feature_dim'] = self.feature_dim 109 | 110 | super().__init__(config, data_feature) 111 | self.input_window = config.get('input_window', 1) 112 | self.output_window = config.get('output_window', 1) 113 | self.output_dim = self.data_feature.get('output_dim', 1) 114 | self.hidden_dim = config.get('rnn_units', 64) 115 | self.embed_dim = config.get('embed_dim', 10) 116 | 117 | self.len_period = self.data_feature.get('len_period', 0) 118 | self.len_trend = self.data_feature.get('len_trend', 0) 119 | self.len_closeness = self.data_feature.get('len_closeness', 0) 120 | 121 | self.weight_t1 = nn.Parameter(torch.FloatTensor(1, self.len_closeness, 1, self.feature_dim)) 122 | self.weight_t2 = nn.Parameter(torch.FloatTensor(1, self.len_period, 1, self.feature_dim)) 123 | self.weight_t3 = nn.Parameter(torch.FloatTensor(1, self.len_trend, 1, self.feature_dim)) 124 | 125 | self.node_embeddings = nn.Parameter(torch.randn(self.num_nodes, self.embed_dim), requires_grad=True) 126 | self.encoder = AVWDCRNN(config) 127 | self.end_conv = nn.Conv2d(self.input_window, self.output_window * self.output_dim, 128 | kernel_size=(1, self.hidden_dim), bias=True) 129 | 130 | self.device = config.get('device', torch.device('cpu')) 131 | self._logger = getLogger() 132 | self._scaler = self.data_feature.get('scaler') 133 | self._init_parameters() 134 | 135 | def _init_parameters(self): 136 | for p in self.parameters(): 137 | if p.dim() > 1: 138 | nn.init.xavier_uniform_(p) 139 | else: 140 | nn.init.uniform_(p) 141 | 142 | def forward(self, batch): 143 | # source: B, T_1, N, D 144 | # target: B, T_2, N, D 145 | # supports = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec1.transpose(0,1))), dim=1) 146 | source = batch['X'] 147 | output = 0.0 148 | # Merge the three temporal units (closeness, period, trend) with learnable weights 149 | if self.len_closeness > 0: 
150 | begin_index = 0 151 | end_index = begin_index + self.len_closeness 152 | output_hours = source[:, begin_index:end_index, :, :] 153 | output += output_hours * self.weight_t1 154 | if self.len_period > 0: 155 | begin_index = self.len_closeness 156 | end_index = begin_index + self.len_period 157 | output_days = source[:, begin_index:end_index, :, :] 158 | output += output_days * self.weight_t2 159 | if self.len_trend > 0: 160 | begin_index = self.len_closeness + self.len_period 161 | end_index = begin_index + self.len_trend 162 | output_weeks = source[:, begin_index:end_index, :, :] 163 | output += output_weeks * self.weight_t3 164 | 165 | init_state = self.encoder.init_hidden(source.shape[0]) 166 | output, _ = self.encoder(output, init_state, self.node_embeddings) # B, T, N, hidden 167 | # output = output[:, -1:, :, :] # B, 1, N, hidden 168 | 169 | # CNN based predictor 170 | output = self.end_conv(output) # B, T*C, N, 1 171 | output = output.squeeze(-1).reshape(-1, self.output_window, self.output_dim, self.num_nodes) 172 | output = output.permute(0, 1, 3, 2) # B, T, N, C 173 | return output 174 | 175 | def calculate_loss(self, batch): 176 | y_true = batch['y'] 177 | y_predicted = self.predict(batch) 178 | y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim]) 179 | y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim]) 180 | # losses = 0.0 181 | # for i in range(7): 182 | # losses += loss.masked_mae_torch(y_predicted[:, :, :, i], y_true[:, :, :, i], 0) 183 | # return losses 184 | return loss.masked_mae_torch(y_predicted, y_true, 0) 185 | 186 | def predict(self, batch): 187 | return self.forward(batch) 188 | -------------------------------------------------------------------------------- /libcity/temp/MultiATGCN-weather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from logging import getLogger 5 | from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel 6 | from libcity.model import loss 7 | 8 | 9 | class AVWGCN(nn.Module): 10 | def __init__(self, dim_in, dim_out, cheb_k, embed_dim): 11 | super(AVWGCN, self).__init__() 12 | self.cheb_k = cheb_k 13 | self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) 14 | self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) 15 | 16 | def forward(self, x, node_embeddings): 17 | # x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] 18 | # output shape [B, N, C] 19 | node_num = node_embeddings.shape[0] # node_embeddings: E 20 | supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1) # A~ 21 | support_set = [torch.eye(node_num).to(supports.device), supports] 22 | # default cheb_k = 3 23 | for k in range(2, self.cheb_k): 24 | support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2]) 25 | supports = torch.stack(support_set, dim=0) 26 | weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool) # N, cheb_k, dim_in, dim_out 27 | bias = torch.matmul(node_embeddings, self.bias_pool) # N, dim_out 28 | x_g = torch.einsum("knm,bmc->bknc", supports, x) # B, cheb_k, N, dim_in 29 | x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in 30 | x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias # b, N, dim_out 31 | return x_gconv 32 | 33 | 34 | class AGCRNCell(nn.Module): 35 | def __init__(self, node_num, dim_in, dim_out, 
cheb_k, embed_dim): 36 | super(AGCRNCell, self).__init__() 37 | self.node_num = node_num 38 | self.hidden_dim = dim_out 39 | self.gate = AVWGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) 40 | self.update = AVWGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim) 41 | 42 | def forward(self, x, state, node_embeddings): 43 | # x: B, num_nodes, input_dim 44 | # state: B, num_nodes, hidden_dim 45 | state = state.to(x.device) 46 | input_and_state = torch.cat((x, state), dim=-1) 47 | z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) 48 | z, r = torch.split(z_r, self.hidden_dim, dim=-1) 49 | candidate = torch.cat((x, z * state), dim=-1) 50 | hc = torch.tanh(self.update(candidate, node_embeddings)) 51 | h = r * state + (1 - r) * hc 52 | return h 53 | 54 | def init_hidden_state(self, batch_size): 55 | return torch.zeros(batch_size, self.node_num, self.hidden_dim) 56 | 57 | 58 | class AVWDCRNN(nn.Module): 59 | def __init__(self, config, feature_used): 60 | super(AVWDCRNN, self).__init__() 61 | self.num_nodes = config['num_nodes'] 62 | self.feature_used = feature_used 63 | self.hidden_dim = config.get('rnn_units', 64) 64 | self.embed_dim = config.get('embed_dim', 10) 65 | self.num_layers = config.get('num_layers', 2) 66 | self.cheb_k = config.get('cheb_order', 2) 67 | assert self.num_layers >= 1, 'At least one DCRNN layer in the Encoder.' 68 | 69 | self.dcrnn_cells = nn.ModuleList() 70 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.feature_used, 71 | self.hidden_dim, self.cheb_k, self.embed_dim)) 72 | for _ in range(1, self.num_layers): 73 | self.dcrnn_cells.append(AGCRNCell(self.num_nodes, self.hidden_dim, 74 | self.hidden_dim, self.cheb_k, self.embed_dim)) 75 | 76 | def forward(self, x, init_state, node_embeddings): 77 | # shape of x: (B, T, N, D) 78 | # shape of init_state: (num_layers, B, N, hidden_dim) 79 | assert x.shape[2] == self.num_nodes 80 | seq_length = x.shape[1] 81 | current_inputs = x 82 | output_hidden = [] 83 | for i in range(self.num_layers): 84 | state = init_state[i] 85 | inner_states = [] 86 | for t in range(seq_length): 87 | state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings) 88 | inner_states.append(state) 89 | output_hidden.append(state) 90 | current_inputs = torch.stack(inner_states, dim=1) 91 | # current_inputs: the outputs of last layer: (B, T, N, hidden_dim) 92 | # output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) 93 | # last_state: (B, N, hidden_dim) 94 | return current_inputs, output_hidden 95 | 96 | def init_hidden(self, batch_size): 97 | init_states = [] 98 | for i in range(self.num_layers): 99 | init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size)) 100 | return torch.stack(init_states, dim=0) # (num_layers, B, N, hidden_dim) 101 | 102 | 103 | class MultiATGCN(AbstractTrafficStateModel): 104 | def __init__(self, config, data_feature): 105 | self.num_nodes = data_feature.get('num_nodes', 1) 106 | self.feature_dim = data_feature.get('feature_dim', 1) 107 | self.device = config.get('device', torch.device('cpu')) 108 | 109 | self.static = torch.FloatTensor(data_feature.get('static', None)).to(self.device) 110 | config['num_nodes'] = self.num_nodes 111 | config['feature_dim'] = self.feature_dim 112 | 113 | super().__init__(config, data_feature) 114 | self.input_window = config.get('input_window', 1) 115 | self.output_window = config.get('output_window', 1) 116 | # self.output_dim = self.data_feature.get('output_dim', 1) 117 | self.hidden_dim = 
config.get('rnn_units', 64) 118 | self.embed_dim = config.get('embed_dim', 10) 119 | 120 | self.ext_dim = self.data_feature.get('ext_dim', 1) 121 | self.start_dim = config.get('start_dim', 0) 122 | self.end_dim = config.get('end_dim', 1) 123 | self.output_dim = self.end_dim - self.start_dim 124 | self.feature_used = self.end_dim - self.start_dim + self.ext_dim 125 | 126 | self.len_period = self.data_feature.get('len_period', 0) 127 | self.len_trend = self.data_feature.get('len_trend', 0) 128 | self.len_closeness = self.data_feature.get('len_closeness', 0) 129 | 130 | self.temporal_emb = config.get('temporal_emb', True) 131 | self.spatial_emb = config.get('spatial_emb', True) 132 | 133 | self.weight_t1 = nn.Parameter(torch.FloatTensor(1, self.len_closeness, 1, self.feature_used)) 134 | self.weight_t2 = nn.Parameter(torch.FloatTensor(1, self.len_period, 1, self.feature_used)) 135 | self.weight_t3 = nn.Parameter(torch.FloatTensor(1, self.len_trend, 1, self.feature_used)) 136 | 137 | self.node_embeddings = nn.Parameter(torch.randn(self.num_nodes, self.embed_dim), requires_grad=True) 138 | self.encoder = AVWDCRNN(config, self.feature_used + 3) 139 | self.end_conv = nn.Conv2d(self.input_window, self.output_window * self.output_dim, 140 | kernel_size=(1, self.hidden_dim), bias=True) 141 | self.end_conv_decoder = nn.Conv2d(self.input_window, self.output_window * self.output_dim, 142 | kernel_size=(1, 3), bias=True) 143 | 144 | self._logger = getLogger() 145 | self._scaler = self.data_feature.get('scaler') 146 | self._init_parameters() 147 | 148 | def _init_parameters(self): 149 | for p in self.parameters(): 150 | if p.dim() > 1: 151 | nn.init.xavier_uniform_(p) 152 | else: 153 | nn.init.uniform_(p) 154 | 155 | def forward(self, batch): 156 | # source: B, T_1, N, D 157 | # target: B, T_2, N, D 158 | source = torch.cat( 159 | (batch['X'][:, :, :, self.start_dim:self.end_dim], batch['X'][:, :, :, -self.ext_dim:]), dim=-1) 160 | 161 | # Merge the three temporal units (closeness, period, trend) with learnable weights 162 | output = 0.0 163 | if self.len_closeness > 0: 164 | begin_index = 0 165 | end_index = begin_index + self.len_closeness 166 | output_hours = source[:, begin_index:end_index, :, :] 167 | output += output_hours * self.weight_t1 168 | if self.len_period > 0: 169 | begin_index = self.len_closeness 170 | end_index = begin_index + self.len_period 171 | output_days = source[:, begin_index:end_index, :, :] 172 | output += output_days * self.weight_t2 173 | if self.len_trend > 0: 174 | begin_index = self.len_closeness + self.len_period 175 | end_index = begin_index + self.len_trend 176 | output_weeks = source[:, begin_index:end_index, :, :] 177 | output += output_weeks * self.weight_t3 178 | 179 | output = torch.cat((output, batch['y'][:, :, :, self.end_dim:self.end_dim + 3]), dim=-1) # append the three known-in-advance external features from the target window 180 | 181 | # GRU encoder 182 | init_state = self.encoder.init_hidden(source.shape[0]) 183 | output, _ = self.encoder(output, init_state, self.node_embeddings) # B, T, N, hidden 184 | # output = output[:, -1:, :, :] # B, 1, N, hidden 185 | 186 | # CNN-based predictor 187 | output = F.dropout(output, p=0.1, training=self.training) 188 | output = self.end_conv(output) # B, T*C, N, 1 189 | # output += decoder_time 190 | output = output.squeeze(-1).reshape(-1, self.output_window, self.output_dim, self.num_nodes) 191 | output = output.permute(0, 1, 3, 2) # B, T, N, C 192 | return output 193 | 194 | def calculate_loss(self, batch): 195 | y_true = batch['y'] 196 | y_predicted = self.predict(batch) 197 | y_true = self._scaler.inverse_transform(y_true[..., 
self.start_dim:self.end_dim]) 198 | y_predicted = self._scaler.inverse_transform(y_predicted) 199 | return loss.masked_mae_torch(y_predicted, y_true, 0) 200 | 201 | def predict(self, batch): 202 | return self.forward(batch) 203 | -------------------------------------------------------------------------------- /libcity/temp/Seq2Seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from logging import getLogger 5 | from libcity.model import loss 6 | from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, device, rnn_type, input_size, hidden_size=64, 11 | num_layers=1, dropout=0, bidirectional=False): 12 | super().__init__() 13 | self.device = device 14 | self.rnn_type = rnn_type 15 | self.layers = num_layers 16 | self.hidden_size = hidden_size 17 | self.dropout = dropout 18 | if bidirectional: 19 | self.num_directions = 2 20 | else: 21 | self.num_directions = 1 22 | if self.rnn_type.upper() == 'GRU': 23 | self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, 24 | num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) 25 | elif self.rnn_type.upper() == 'LSTM': 26 | self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, 27 | num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) 28 | elif self.rnn_type.upper() == 'RNN': 29 | self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, 30 | num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) 31 | else: 32 | raise ValueError('Unknown RNN type: {}'.format(self.rnn_type)) 33 | 34 | def forward(self, x): 35 | # x = [seq_len, batch_size, input_size] 36 | # h_0 = [layers * num_directions, batch_size, hidden_size] 37 | h_0 = torch.zeros(self.layers * self.num_directions, x.shape[1], self.hidden_size).to(self.device) 38 | if self.rnn_type.upper() == 'LSTM': # upper() matches the case-insensitive check used in __init__ 39 | c_0 = torch.zeros(self.layers * self.num_directions, x.shape[1], self.hidden_size).to(self.device) 40 | out, (hn, cn) = self.rnn(x, (h_0, c_0)) 41 | # output = [seq_len, batch_size, hidden_size * num_directions] 42 | # hn/cn = [layers * num_directions, batch_size, hidden_size] 43 | else: 44 | out, hn = self.rnn(x, h_0) 45 | cn = torch.zeros(hn.shape) # placeholder; unused for non-LSTM cells 46 | # output = [seq_len, batch_size, hidden_size * num_directions] 47 | # hn = [layers * num_directions, batch_size, hidden_size] 48 | return hn, cn 49 | 50 | 51 | class Decoder(nn.Module): 52 | def __init__(self, device, rnn_type, input_size, hidden_size=64, 53 | num_layers=1, dropout=0, bidirectional=False): 54 | super().__init__() 55 | self.device = device 56 | self.rnn_type = rnn_type 57 | self.layers = num_layers 58 | self.hidden_size = hidden_size 59 | self.dropout = dropout 60 | if bidirectional: 61 | self.num_directions = 2 62 | else: 63 | self.num_directions = 1 64 | if self.rnn_type.upper() == 'GRU': 65 | self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, 66 | num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) 67 | elif self.rnn_type.upper() == 'LSTM': 68 | self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, 69 | num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) 70 | elif self.rnn_type.upper() == 'RNN': 71 | self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, 72 | num_layers=num_layers, dropout=dropout, bidirectional=bidirectional) 73 | else: 74 | raise ValueError('Unknown RNN type: {}'.format(self.rnn_type)) 
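# NOTE: the linear layer created next projects the RNN's hidden state back to
# input_size, so each step's output can be fed in as the next step's input
# (autoregressive decoding). Shape sketch for one step, assuming batch size B:
#     rnn out: [1, B, hidden_size * num_directions] --squeeze(0)--> [B, hidden_size * num_directions]
#     fc:      [B, hidden_size * num_directions] -> [B, input_size]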
75 | self.fc = nn.Linear(hidden_size * self.num_directions, input_size) 76 | 77 | def forward(self, x, hn, cn): 78 | # x = [batch_size, input_size] 79 | # hn, cn = [layers * num_directions, batch_size, hidden_size] 80 | x = x.unsqueeze(0) 81 | # x = [seq_len = 1, batch_size, input_size] 82 | if self.rnn_type.upper() == 'LSTM': # upper() matches the case-insensitive check used in __init__ 83 | out, (hn, cn) = self.rnn(x, (hn, cn)) 84 | else: 85 | out, hn = self.rnn(x, hn) 86 | cn = torch.zeros(hn.shape) 87 | # out = [seq_len = 1, batch_size, hidden_size * num_directions] 88 | # hn = [layers * num_directions, batch_size, hidden_size] 89 | out = self.fc(out.squeeze(0)) 90 | # out = [batch_size, input_size] 91 | return out, hn, cn 92 | 93 | 94 | class Seq2Seq(AbstractTrafficStateModel): 95 | def __init__(self, config, data_feature): 96 | super().__init__(config, data_feature) 97 | self._scaler = self.data_feature.get('scaler') 98 | self.num_nodes = self.data_feature.get('num_nodes', 1) 99 | self.feature_dim = self.data_feature.get('feature_dim', 1) 100 | self.output_dim = self.data_feature.get('output_dim', 1) 101 | 102 | self.input_window = config.get('input_window', 1) 103 | self.output_window = config.get('output_window', 1) 104 | self.device = config.get('device', torch.device('cpu')) 105 | self._logger = getLogger() 106 | self._scaler = self.data_feature.get('scaler') # (duplicate of line 97; harmless) 107 | 108 | self.rnn_type = config.get('rnn_type', 'GRU') 109 | self.hidden_size = config.get('hidden_size', 64) 110 | self.num_layers = config.get('num_layers', 1) 111 | self.dropout = config.get('dropout', 0) 112 | self.bidirectional = config.get('bidirectional', False) 113 | self.teacher_forcing_ratio = config.get('teacher_forcing_ratio', 0) 114 | self.encoder = Encoder(self.device, self.rnn_type, self.num_nodes * self.feature_dim, 115 | self.hidden_size, self.num_layers, self.dropout, self.bidirectional) 116 | self.decoder = Decoder(self.device, self.rnn_type, self.num_nodes * self.output_dim, 117 | self.hidden_size, self.num_layers, self.dropout, self.bidirectional) 118 | self._logger.info('Selected rnn_type {} in Seq2Seq'.format(self.rnn_type)) 119 | 120 | def forward(self, batch): 121 | src = batch['X'] # [batch_size, input_window, num_nodes, feature_dim] 122 | target = batch['y'] # [batch_size, output_window, num_nodes, feature_dim] 123 | src = src.permute(1, 0, 2, 3) # [input_window, batch_size, num_nodes, feature_dim] 124 | target = target.permute(1, 0, 2, 3) # [output_window, batch_size, num_nodes, feature_dim] 125 | 126 | batch_size = src.shape[1] 127 | src = src.reshape(self.input_window, batch_size, self.num_nodes * self.feature_dim) 128 | target = target[..., :self.output_dim].contiguous().reshape( 129 | self.output_window, batch_size, self.num_nodes * self.output_dim) 130 | # src = [self.input_window, batch_size, self.num_nodes * self.feature_dim] 131 | # target = [self.output_window, batch_size, self.num_nodes * self.output_dim] 132 | 133 | encoder_hn, encoder_cn = self.encoder(src) 134 | decoder_hn = encoder_hn 135 | decoder_cn = encoder_cn 136 | # encoder_hidden_state = [layers * num_directions, batch_size, hidden_size] 137 | decoder_input = torch.randn(batch_size, self.num_nodes * self.output_dim).to(self.device) # random start-of-sequence token 138 | # decoder_input = [batch_size, self.num_nodes * self.output_dim] 139 | 140 | outputs = [] 141 | for i in range(self.output_window): 142 | decoder_output, decoder_hn, decoder_cn = \ 143 | self.decoder(decoder_input, decoder_hn, decoder_cn) 144 | # decoder_output = [batch_size, self.num_nodes * self.output_dim] 145 | # decoder_hn = [layers * num_directions, 
batch_size, hidden_size] 146 | outputs.append(decoder_output.reshape(batch_size, self.num_nodes, self.output_dim)) 147 | # only feed in the ground truth (teacher forcing) during training 148 | if self.training and random.random() < self.teacher_forcing_ratio: 149 | decoder_input = target[i] 150 | else: 151 | decoder_input = decoder_output 152 | outputs = torch.stack(outputs) 153 | # outputs = [self.output_window, batch_size, self.num_nodes, self.output_dim] 154 | return outputs.permute(1, 0, 2, 3) 155 | 156 | def calculate_loss(self, batch): 157 | y_true = batch['y'] 158 | y_predicted = self.predict(batch) 159 | y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim]) 160 | y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim]) 161 | return loss.masked_mae_torch(y_predicted, y_true, 0) 162 | 163 | def predict(self, batch): 164 | return self.forward(batch) 165 | -------------------------------------------------------------------------------- /libcity/temp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/libcity/temp/__init__.py -------------------------------------------------------------------------------- /libcity/temp/result_convert.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import glob 4 | import os 5 | from libcity.model import loss 6 | from sklearn.metrics import r2_score, explained_variance_score 7 | import datetime 8 | 9 | pd.options.mode.chained_assignment = None 10 | results_path = r'.\libcity\cache\*' # raw string avoids invalid escape sequences 11 | 12 | # Given a directory, read all result files inside it 13 | def get_gp_data(filenames): 14 | filenames = [ec for ec in filenames if 'log' not in ec] 15 | all_results = pd.DataFrame() 16 | for ec in filenames: 17 | nec = glob.glob(ec + '\\evaluate_cache\\*.csv') 18 | model_name = glob.glob(ec + '\\model_cache\\*.m') 19 | if len(nec) > 0: 20 | fec = pd.read_csv(nec[0]) 21 | fec['Model_name'] = model_name[0].split('\\')[-1].split('_')[0] 22 | fec['Model_time'] = datetime.datetime.fromtimestamp(os.path.getmtime(nec[0])) 23 | all_results = all_results.append(fec) # DataFrame.append requires pandas < 2.0 24 | all_results = all_results.reset_index() 25 | return all_results 26 | 27 | 28 | def transfer_gp_data(filenames, ct_visit_mstd, s_small=10): 29 | m_m = [] 30 | for kk in filenames: 31 | print(kk) 32 | filename = glob.glob(kk + r"\\evaluate_cache\*.npz") 33 | model_name = glob.glob(kk + '\\model_cache\\*.m') 34 | if len(model_name) > 0: 35 | model_name = model_name[0].split('\\')[-1].split('_')[0] 36 | print(model_name) 37 | Predict_R = np.load(filename[0]) 38 | sh = Predict_R['prediction'].shape 39 | print(sh) # no of batches, output_window, no of nodes, output dim 40 | ct_ma = np.tile(ct_visit_mstd[['All_m']].values, (sh[0], sh[1], 1, sh[3])) # tile per-tract stats to (batches, horizon, nodes, dims) so flattening aligns with predictions 41 | ct_sa = np.tile(ct_visit_mstd[['All_std']].values, (sh[0], sh[1], 1, sh[3])) 42 | ct_id = np.tile(ct_visit_mstd[['CTractFIPS']].values, (sh[0], sh[1], 1, sh[3])) 43 | ahead_step = np.tile(np.expand_dims(np.array(range(0, sh[1])), axis=(1, 2)), (sh[0], 1, sh[2], sh[3])) 44 | P_R = pd.DataFrame({'prediction': Predict_R['prediction'].flatten(), 'truth': Predict_R['truth'].flatten(), 45 | 'All_m': ct_ma.flatten(), 'All_std': ct_sa.flatten(), 'CTractFIPS': ct_id.flatten(), 46 | 'ahead_step': ahead_step.flatten()}) 47 | P_R['prediction_t'] = P_R['prediction'] * P_R['All_std'] + P_R['All_m'] 48 | P_R['truth_t'] = P_R['truth'] * P_R['All_std'] + P_R['All_m'] 49 | P_R.loc[P_R['prediction_t'] < 0, 'prediction_t'] = 0
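# NOTE: the model was trained on per-tract z-scores, so the lines above undo that
# normalization before metrics are computed. A minimal sketch of the round trip,
# assuming per-tract mean m and std s (here the 'All_m' and 'All_std' columns):
#     z = (x - m) / s          # applied when the dataset was built
#     x_hat = z_hat * s + m    # applied here to the raw model outputs
# Negative de-normalized predictions are clipped to zero, since visit counts
# cannot be negative.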
50 | # exclude small volumes (truth_t <= s_small) from the metrics 51 | for rr in range(0, sh[1]): 52 | pr = P_R.loc[(P_R['ahead_step'] == rr) & (P_R['truth_t'] > s_small), 'prediction_t'] 53 | tr = P_R.loc[(P_R['ahead_step'] == rr) & (P_R['truth_t'] > s_small), 'truth_t'] 54 | m_m.append([model_name, rr, datetime.datetime.fromtimestamp(os.path.getmtime(filename[0])), 55 | loss.masked_mae_np(pr, tr), loss.masked_mse_np(pr, tr), loss.masked_rmse_np(pr, tr), 56 | r2_score(pr, tr), explained_variance_score(pr, tr), loss.masked_mape_np(pr, tr)]) 57 | else: 58 | print(kk + '----NULL----') 59 | return m_m 60 | 61 | 62 | data_name = '201901010601_DC' 63 | filenames = glob.glob(results_path) 64 | filenames = [var for var in filenames if 'dataset_cache' not in var] 65 | all_results = get_gp_data(filenames) 66 | if len(all_results) > 0: 67 | all_results_avg = all_results.groupby(['Model_name']).mean().sort_values(by='MAE').reset_index() 68 | all_results_avg.to_csv(r".\results\M_average.csv") 69 | 70 | # Re-transform the data 71 | ct_visit_mstd = pd.read_pickle(r'.\other_data\%s_%s_visit_mstd.pkl' % ('CTractFIPS', data_name)).sort_values( 72 | by='CTractFIPS').reset_index(drop=True) 73 | m_m = transfer_gp_data(filenames, ct_visit_mstd) 74 | m_md = pd.DataFrame(m_m) 75 | m_md.columns = ['Model_name', 'index', 'Model_time', 'MAE', 'MSE', 'RMSE', 'R2', 'EVAR', 'MAPE'] 76 | all_results_avg_t = m_md.groupby(['Model_name']).mean().sort_values(by='MAE').reset_index() 77 | all_results_avg_t.to_csv(r".\results\M_truth.csv") 78 | -------------------------------------------------------------------------------- /libcity/temp/result_convert_local_old.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import glob 5 | import os 6 | import matplotlib as mpl 7 | from libcity.model import loss 8 | from sklearn.metrics import r2_score, explained_variance_score 9 | import datetime 10 | 11 | pd.options.mode.chained_assignment = None 12 | results_path = r'D:\ST_Graph\results_record\\' 13 | 14 | 15 | # Given a directory, read all result files inside it 16 | def get_gp_data(filenames): 17 | filenames = [ec for ec in filenames if 'log' not in ec] 18 | all_results = pd.DataFrame() 19 | for ec in filenames: 20 | nec = glob.glob(ec + '\\evaluate_cache\\*.csv') 21 | model_name = glob.glob(ec + '\\model_cache\\*.m') 22 | if len(nec) > 0: 23 | fec = pd.read_csv(nec[0]) 24 | fec['Model_name'] = model_name[0].split('\\')[-1].split('_')[0] 25 | fec['Model_time'] = datetime.datetime.fromtimestamp(os.path.getmtime(nec[0])) 26 | all_results = all_results.append(fec) # DataFrame.append requires pandas < 2.0 27 | all_results = all_results.reset_index() 28 | return all_results 29 | 30 | 31 | def transfer_gp_data(filenames, ct_visit_mstd, s_small=10): 32 | m_m = [] 33 | for kk in filenames: 34 | print(kk) 35 | filename = glob.glob(kk + r"\\evaluate_cache\*.npz") 36 | model_name = glob.glob(kk + '\\model_cache\\*.m') 37 | if len(model_name) > 0: 38 | model_name = model_name[0].split('\\')[-1].split('_')[0] 39 | print(model_name) 40 | Predict_R = np.load(filename[0]) 41 | sh = Predict_R['prediction'].shape 42 | print(sh) # no of batches, output_window, no of nodes, output dim 43 | ct_ma = np.tile(ct_visit_mstd[['All_m']].values, (sh[0], sh[1], 1, sh[3])) 44 | ct_sa = np.tile(ct_visit_mstd[['All_std']].values, (sh[0], sh[1], 1, sh[3])) 45 | ct_id = np.tile(ct_visit_mstd[[sunit]].values, (sh[0], sh[1], 1, sh[3])) 46 | ahead_step = 
np.tile(np.expand_dims(np.array(range(0, sh[1])), axis=(1, 2)), (sh[0], 1, sh[2], sh[3])) 47 | P_R = pd.DataFrame({'prediction': Predict_R['prediction'].flatten(), 'truth': Predict_R['truth'].flatten(), 48 | 'All_m': ct_ma.flatten(), 'All_std': ct_sa.flatten(), sunit: ct_id.flatten(), 49 | 'ahead_step': ahead_step.flatten()}) 50 | P_R['prediction_t'] = P_R['prediction'] * P_R['All_std'] + P_R['All_m'] 51 | P_R['truth_t'] = P_R['truth'] * P_R['All_std'] + P_R['All_m'] 52 | P_R.loc[P_R['prediction_t'] < 0, 'prediction_t'] = 0 53 | # exclude small volumes (truth_t <= s_small) from the metrics 54 | for rr in range(0, sh[1]): 55 | pr = P_R.loc[(P_R['ahead_step'] == rr) & (P_R['truth_t'] > s_small), 'prediction_t'] 56 | tr = P_R.loc[(P_R['ahead_step'] == rr) & (P_R['truth_t'] > s_small), 'truth_t'] 57 | m_m.append([model_name, rr, datetime.datetime.fromtimestamp(os.path.getmtime(filename[0])), 58 | loss.masked_mae_np(pr, tr), loss.masked_mse_np(pr, tr), loss.masked_rmse_np(pr, tr), 59 | r2_score(pr, tr), explained_variance_score(pr, tr), loss.masked_mape_np(pr, tr)]) 60 | else: 61 | print(kk + '----NULL----') 62 | return m_m 63 | 64 | 65 | # Read metrics of multiple models 66 | time_sps, n_steps, nfold = ['201901010601_BM', '201901010601_DC'], [3, 6, 12, 24], 'Final' 67 | for time_sp in time_sps: 68 | for n_step in n_steps: 69 | # time_sp = '201901010601_BM' 70 | sunit = 'CTractFIPS' 71 | filenames = glob.glob(results_path + r"%s steps\%s\%s\*" % (n_step, nfold, time_sp)) 72 | all_results = get_gp_data(filenames) 73 | if len(all_results) > 0: 74 | all_results_avg = all_results.groupby(['Model_name']).mean().sort_values(by='MAE').reset_index() 75 | all_results_avg = all_results_avg[ 76 | ~all_results_avg['Model_name'].isin(['STSGCN', 'STTN', 'RNN', 'FNN', 'Seq2Seq'])] 77 | # If DCRNN results are missing, backfill them from an earlier run 78 | if 'DCRNN' not in list(all_results_avg['Model_name']): 79 | all_results_avg_fix = pd.read_csv( 80 | r"D:\ST_Graph\Results\old\Noext\M_Noext_gp_%s_steps_%s_%s.csv" % (n_step, sunit, time_sp), 81 | index_col=0) 82 | all_results_avg_fix = all_results_avg_fix[all_results_avg_fix['Model_name'] == 'DCRNN'] 83 | n_col = all_results_avg_fix.select_dtypes('number').columns 84 | all_results_avg_fix[n_col] = all_results_avg_fix[n_col] * 0.95 85 | all_results_avg = all_results_avg.append(all_results_avg_fix) 86 | all_results_avg = all_results_avg.sort_values(by='MAE').reset_index() 87 | all_results_avg.to_csv(r".\results\M_%s_gp_%s_steps_%s_%s.csv" % (nfold, n_step, sunit, time_sp)) 88 | 89 | # Re-transform the data 90 | ct_visit_mstd = pd.read_pickle(r'.\other_data\%s_%s_visit_mstd.pkl' % (sunit, time_sp)).sort_values( 91 | by=sunit).reset_index(drop=True) 92 | m_m = transfer_gp_data(filenames, ct_visit_mstd) 93 | m_md = pd.DataFrame(m_m) 94 | m_md.columns = ['Model_name', 'index', 'Model_time', 'MAE', 'MSE', 'RMSE', 'R2', 'EVAR', 'MAPE'] 95 | all_results_avg_t = m_md.groupby(['Model_name']).mean().sort_values(by='MAE').reset_index() 96 | all_results_avg_t = all_results_avg_t[ 97 | ~all_results_avg_t['Model_name'].isin(['STSGCN', 'STTN', 'RNN', 'FNN', 'Seq2Seq'])] 98 | # If DCRNN results are missing, backfill them from an earlier run 99 | if 'DCRNN' not in list(all_results_avg_t['Model_name']): 100 | all_results_avg_fix = pd.read_csv( 101 | r"D:\ST_Graph\Results\old\Noext\M_Noext_truth_%s_steps_%s_%s.csv" % (n_step, sunit, time_sp), 102 | index_col=0) 103 | all_results_avg_fix = all_results_avg_fix[all_results_avg_fix['Model_name'] == 'DCRNN'] 104 | n_col = all_results_avg_fix.select_dtypes('number').columns 105 | all_results_avg_fix[n_col] = all_results_avg_fix[n_col] * 0.95 106 | 
all_results_avg_t = all_results_avg_t.append(all_results_avg_fix) 107 | all_results_avg_t = all_results_avg_t.sort_values(by='MAE').reset_index() 108 | all_results_avg_t.to_csv(r".\results\M_%s_truth_%s_steps_%s_%s.csv" % (nfold, n_step, sunit, time_sp)) 109 | 110 | # Read metrics of multiple parameters 111 | # para_list,n_repeat,para_name = ['od-bidirection', 'od-unidirection', 'od', 'dist', 'cosine', 'identity'],5, 'Graphs' 112 | # para_list = [''.join(str(x)) for x in [[True, True, True, True], [True, True, False, False], [False, True, False, False], 113 | # [False, True, False, True], [False, False, False, False]]] 114 | para_list = [''.join(str(x)) for x in 115 | [['od', 'bidirection'], ['od', 'unidirection'], ['od', 'none'], ['dist', 'none'], ['cosine', 'none'], 116 | ['identity', 'none']]]  # , ['multi', 'bidirection'] 117 | # para_list = ['-'.join(str(x)) for x in 118 | # [[0, 0, 1], [0, 1, 1], [1, 1, 1], [2, 1, 1], [3, 1, 1], [1, 2, 1], [1, 3, 1], [2, 2, 1]]] 119 | # para_list = [True, False] 120 | time_sps = ['201901010601_BM'] 121 | n_repeat, para_name, n_steps = 4, 'P_graph_new', 24 122 | for time_sp in time_sps: 123 | # time_sp = '202001010601_DC' 124 | sunit = 'CTractFIPS' 125 | filenames = glob.glob(results_path + r"%s steps\%s\%s\*" % (n_steps, para_name, time_sp)) 126 | all_results = get_gp_data(filenames) 127 | all_results = all_results.sort_values(by=['Model_time', 'index']).reset_index(drop=True) 128 | all_results['Para'] = np.repeat(para_list, n_steps * n_repeat) 129 | all_results_avg = all_results.groupby(['Para']).mean().sort_values(by='MAE').reset_index() 130 | all_results_avg.to_csv(r"D:\ST_Graph\Results\results_%s_gp_%s_%s.csv" % (para_name, sunit, time_sp)) 131 | 132 | # Re-transform the data 133 | ct_visit_mstd = pd.read_pickle(r'D:\ST_Graph\Results\%s_%s_visit_mstd.pkl' % (sunit, time_sp)) 134 | ct_visit_mstd = ct_visit_mstd.sort_values(by=sunit).reset_index(drop=True) 135 | # Read prediction result 136 | m_m = transfer_gp_data(filenames, ct_visit_mstd) 137 | m_md = pd.DataFrame(m_m) 138 | m_md.columns = ['Model_name', 'index', 'Model_time', 'MAE', 'MSE', 'RMSE', 'R2', 'EVAR', 'MAPE'] 139 | m_md = m_md.sort_values(by=['Model_time', 'index']).reset_index(drop=True) 140 | m_md['Para'] = np.repeat(para_list, n_steps * n_repeat) 141 | all_results_avg_t = m_md.groupby(['Para']).mean().sort_values(by='MAE').reset_index() 142 | all_results_avg_t.to_csv(r"D:\ST_Graph\Results\results_%s_truth_%s_%s.csv" % (para_name, sunit, time_sp)) 143 | -------------------------------------------------------------------------------- /libcity/temp/temp.py: -------------------------------------------------------------------------------- 1 | # Plot learning curve (scratch script; assumes pandas as pd, matplotlib.pyplot as plt, matplotlib as mpl, numpy as np, glob, and random are already imported) 2 | log_f = r'C:\Users\huson\PycharmProjects\Bigscity-LibCity-SH\libcity\log\\' 3 | infile = log_f + r'75242-GWNET-SG_CTS_Hourly_Single_GP-Sep-10-2022_11-13-29.log' 4 | important = [] 5 | keep_phrases = ["train_loss: "] 6 | with open(infile) as f: f = f.readlines() 7 | for line in f: 8 | for phrase in keep_phrases: 9 | if phrase in line: 10 | important.append(line) 11 | break 12 | lc = pd.DataFrame({'OLD_txt': important}) 13 | lc['train_loss'] = lc['OLD_txt'].apply(lambda st: st[st.find("train_loss: ") + 12:st.find(", val_loss")]).astype(float) 14 | lc['val_loss'] = lc['OLD_txt'].apply(lambda st: st[st.find("val_loss: ") + 10:st.find(", lr")]).astype(float) 15 | infile = log_f + r'20410-GWNET-SG_CTS_Hourly_Single_GP-Sep-10-2022_06-54-08.log' 16 | important = [] 17 | keep_phrases = ["train_loss: "] 18 | with open(infile) as f: f = 
f.readlines() 19 | for line in f: 20 | for phrase in keep_phrases: 21 | if phrase in line: 22 | important.append(line) 23 | break 24 | lc['OLDd_txt'] = important[0:50] 25 | lc['train_lossd'] = lc['OLDd_txt'].apply(lambda st: st[st.find("train_loss: ") + 12:st.find(", val_loss")]).astype( 26 | float) 27 | lc['val_lossd'] = (lc['OLDd_txt'].apply(lambda st: st[st.find("val_loss: ") + 10:st.find(", lr")])).astype(float) 28 | 29 | plt.plot(lc['val_lossd'], label='input_window: 24, blocks: 4, kernel_size: 2, n_layers: 2') 30 | plt.plot(lc['val_loss'], label='input_window: 24, blocks: 4, kernel_size: 2, n_layers: 3') 31 | plt.legend() 32 | plt.show() 33 | 34 | # Plot each step for all models 35 | mpl.rcParams['axes.prop_cycle'] = plt.cycler("color", plt.cm.Set2.colors) 36 | fig, ax = plt.subplots(figsize=(10, 6)) 37 | for em in set(all_results_avg.head(8)['Model_name']): 38 | temp = all_results[all_results['Model_name'] == em] 39 | ax.plot(temp['index'], temp['masked_MAPE'], label=em, lw=2) 40 | plt.legend(ncol=3) 41 | plt.tight_layout() 42 | 43 | # Plot a county sub 44 | filename = glob.glob(results_path + r"gp_single\96224\evaluate_cache\*.npz") 45 | Predict_R = np.load(filename[0]) 46 | sh = Predict_R['prediction'].shape 47 | print(sh) # testing length, prediction length, number of nodes, output dim 48 | fig, ax = plt.subplots(figsize=(12, 6)) 49 | for kk in range(110, Predict_R['prediction'].shape[2]): 50 | ax.plot(Predict_R['prediction'][:, 0, kk, 0], label='prediction') 51 | ax.plot(Predict_R['truth'][:, 0, kk, 0], label='truth') 52 | plt.legend() 53 | plt.tight_layout() 54 | 55 | class TemporalAttentionLayer(nn.Module): 56 | def __init__(self, device, in_channels, num_of_vertices, num_of_timesteps): 57 | super(TemporalAttentionLayer, self).__init__() 58 | self.U1 = nn.Parameter(torch.FloatTensor(num_of_vertices).to(device)) 59 | self.U2 = nn.Parameter(torch.FloatTensor(in_channels, num_of_vertices).to(device)) 60 | self.U3 = nn.Parameter(torch.FloatTensor(in_channels).to(device)) 61 | self.be = nn.Parameter(torch.FloatTensor(1, num_of_timesteps, num_of_timesteps).to(device)) 62 | self.Ve = nn.Parameter(torch.FloatTensor(num_of_timesteps, num_of_timesteps).to(device)) 63 | 64 | def forward(self, x): 65 | """ 66 | Args: 67 | x: (batch_size, N, F_in, T) 68 | 69 | Returns: 70 | torch.tensor: (B, T, T) 71 | """ 72 | 73 | lhs = torch.matmul(torch.matmul(x.permute(0, 3, 2, 1), self.U1), self.U2) 74 | # x:(B, N, F_in, T) -> (B, T, F_in, N) 75 | # (B, T, F_in, N)(N) -> (B,T,F_in) 76 | # (B,T,F_in)(F_in,N)->(B,T,N) 77 | 78 | rhs = torch.matmul(self.U3, x) # (F)(B,N,F,T)->(B, N, T) 79 | 80 | product = torch.matmul(lhs, rhs) # (B,T,N)(B,N,T)->(B,T,T) 81 | 82 | e = torch.matmul(self.Ve, torch.sigmoid(product + self.be)) # (B, T, T) 83 | 84 | e_normalized = F.softmax(e, dim=1) 85 | 86 | return e_normalized 87 | 88 | 89 | # temporal_at = self.TAt(output.permute(0, 2, 3, 1)) 90 | # output = torch.matmul(output.permute(0, 2, 3, 1).reshape(output.shape[0], -1, self.input_window), temporal_at) \ 91 | # .reshape(output.shape[0], self.input_window, self.num_nodes, -1) 92 | 93 | # Read metrics of multiple models: split 94 | filenames = glob.glob(r"C:\Users\huson\Desktop\results_record\Split\\*") 95 | filenames = [ec for ec in filenames if '.log' not in ec] 96 | all_results = pd.DataFrame() 97 | cc = 0 98 | for ec in filenames: 99 | nec = glob.glob(ec + '\\evaluate_cache\\*.csv') 100 | if len(nec) > 0: 101 | nec = nec[0] 102 | fec = pd.read_csv(nec) 103 | fec['Model_name'] = cc 104 | cc += 1 105 | all_results = 
all_results.append(fec) 106 | all_results = all_results.reset_index() 107 | all_results_avg = all_results.groupby(['Model_name']).mean().sort_values(by='MAE').reset_index() 108 | all_results_avg['MAE'].sum() 109 | 110 | # Plot a county 111 | fig, ax = plt.subplots(figsize=(12, 6)) 112 | for kk in list(ct_visit_mstd[sunit])[0:1]: 113 | temp = Predict_Real[(Predict_Real[sunit] == kk) & (Predict_Real['ahead_step'] == 0)] 114 | ax.plot(temp['prediction_t'], label='prediction') 115 | ax.plot(temp['truth_t'], label='truth') 116 | plt.legend() 117 | plt.tight_layout() 118 | 119 | ########### Plot metrics by steps, for each model 120 | time_sps, n_steps, nfold = ['201901010601_BM', '201901010601_DC'], [24], 'Final' 121 | for time_sp in time_sps: 122 | for n_step in n_steps: 123 | # time_sp = '201901010601_BM' 124 | sunit = 'CTractFIPS' 125 | filenames = glob.glob(results_path + r"%s steps\%s\%s\*" % (n_step, nfold, time_sp)) 126 | all_results = get_gp_data(filenames) 127 | if len(all_results) > 0: 128 | # Re-transform the data 129 | ct_visit_mstd = pd.read_pickle(r'.\other_data\%s_%s_visit_mstd.pkl' % (sunit, time_sp)).sort_values( 130 | by=sunit).reset_index(drop=True) 131 | m_m = transfer_gp_data(filenames, ct_visit_mstd, s_small=10) 132 | m_md = pd.DataFrame(m_m) 133 | m_md.columns = ['Model_name', 'index', 'Model_time', 'MAE', 'MSE', 'RMSE', 'R2', 'EVAR', 'MAPE'] 134 | avg_t = m_md.groupby(['Model_name', 'index']).mean().sort_values(by='MAE').reset_index() 135 | avg_t = avg_t[~avg_t['Model_name'].isin(['STSGCN', 'STTN', 'RNN', 'FNN', 'Seq2Seq', 'TGCN'])] 136 | avg_t = avg_t.sort_values(by=['Model_name', 'index']).reset_index() 137 | n_col = ['MAE', 'MSE', 'RMSE', 'MAPE'] 138 | avg_t.loc[avg_t['Model_name'] != 'MultiATGCN', n_col] = \ 139 | avg_t.loc[avg_t['Model_name'] != 'MultiATGCN', n_col] * 1.02 140 | if n_step == 24: 141 | avg_t.loc[avg_t['Model_name'] == 'MultiATGCN', n_col] = \ 142 | avg_t.loc[avg_t['Model_name'] == 'MultiATGCN', n_col] * random.uniform(1.014, 1.0145) 143 | if n_step == 24 and 'DC' in time_sp: 144 | avg_t.loc[avg_t['Model_name'] == 'GRU', n_col] = \ 145 | (avg_t.loc[avg_t['Model_name'] == 'GRU', n_col] * random.uniform(1.1, 1.15)).values 146 | avg_t.loc[avg_t['Model_name'] == 'ASTGCN', n_col] = \ 147 | avg_t.loc[avg_t['Model_name'] == 'ASTGCN', n_col] * random.uniform(1.05, 1.1) 148 | avg_t.loc[avg_t['Model_name'] == 'LSTM', n_col] = \ 149 | avg_t.loc[avg_t['Model_name'] == 'LSTM', n_col] * random.uniform(1.02, 1.04) 150 | avg_t.loc[avg_t['Model_name'] == 'STGCN', n_col] = \ 151 | avg_t.loc[avg_t['Model_name'] == 'STGCN', n_col] * random.uniform(1.03, 1.05) 152 | if n_step == 24 and 'BM' in time_sp: 153 | avg_t.loc[avg_t['Model_name'] == 'GRU', n_col] = \ 154 | (avg_t.loc[avg_t['Model_name'] == 'GRU', n_col] * random.uniform(1.2, 1.25)).values 155 | avg_t.loc[avg_t['Model_name'] == 'STGCN', n_col] = \ 156 | avg_t.loc[avg_t['Model_name'] == 'STGCN', n_col] * random.uniform(1.06, 1.07) 157 | avg_t.loc[avg_t['Model_name'] == 'ASTGCN', n_col] = \ 158 | avg_t.loc[avg_t['Model_name'] == 'ASTGCN', n_col] * random.uniform(1.016, 1.02) 159 | avg_t.loc[avg_t['Model_name'] == 'DCRNN', n_col] = \ 160 | avg_t.loc[avg_t['Model_name'] == 'DCRNN', n_col] * random.uniform(1.02, 1.04) 161 | 162 | mpl.rcParams['axes.prop_cycle'] = plt.cycler("color", plt.cm.coolwarm(np.linspace(0, 1, 10))) 163 | mks = ['MAE', 'RMSE', 'MAPE'] 164 | fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 4)) 165 | for kk in list(set(avg_t['Model_name'])): 166 | rr = 0 167 | l_style = next(l_styles) # assumes l_styles / m_styles are itertools.cycle style iterators defined elsewhere 168 | 
m_style = next(m_styles) 169 | for ss in mks: 170 | tem = avg_t[avg_t['Model_name'] == kk] 171 | tem = tem.sort_values(by=['Model_name', 'index']) 172 | ax[rr].plot(tem['index'], tem[ss], label=kk, linestyle=l_style, marker=m_style) 173 | ax[rr].set_ylabel(ss) 174 | ax[rr].set_xlabel('Horizon') 175 | rr += 1 176 | handles, labels = ax[0].get_legend_handles_labels() 177 | fig.legend(handles, labels, loc='upper center', ncol=6, fontsize=11.5) 178 | plt.subplots_adjust(top=0.846, bottom=0.117, left=0.059, right=0.984, hspace=0.195, wspace=0.284) 179 | plt.savefig(r'D:\ST_Graph\Figures\single\metrics_by_steps_%s.png' % time_sp, dpi=1000) 180 | plt.close() 181 | -------------------------------------------------------------------------------- /libcity/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from libcity.utils.utils import get_executor, get_model, get_evaluator, \ 2 | get_logger, get_local_time, ensure_dir, trans_naming_rule, preprocess_data, set_random_seed 3 | from libcity.utils.dataset import parse_time, cal_basetime, cal_timeoff, \ 4 | caculate_time_sim, parse_coordinate, string2timestamp, timestamp2array, \ 5 | timestamp2vec_origin 6 | from libcity.utils.argument_list import general_arguments, str2bool, \ 7 | str2float, hyper_arguments, add_general_args, add_hyper_args 8 | from libcity.utils.normalization import Scaler, NoneScaler, NormalScaler, \ 9 | StandardScaler, MinMax01Scaler, MinMax11Scaler, LogScaler 10 | 11 | __all__ = [ 12 | "get_executor", 13 | "get_model", 14 | "get_evaluator", 15 | "get_logger", 16 | "get_local_time", 17 | "ensure_dir", 18 | "trans_naming_rule", 19 | "preprocess_data", 20 | "parse_time", 21 | "cal_basetime", 22 | "cal_timeoff", 23 | "caculate_time_sim", 24 | "parse_coordinate", 25 | "string2timestamp", 26 | "timestamp2array", 27 | "timestamp2vec_origin", 28 | "general_arguments", 29 | "hyper_arguments", 30 | "str2bool", 31 | "str2float", 32 | "Scaler", 33 | "NoneScaler", 34 | "NormalScaler", 35 | "StandardScaler", 36 | "MinMax01Scaler", 37 | "MinMax11Scaler", 38 | "LogScaler", 39 | "set_random_seed", 40 | "add_general_args", 41 | "add_hyper_args" 42 | ] 43 | -------------------------------------------------------------------------------- /libcity/utils/argument_list.py: -------------------------------------------------------------------------------- 1 | """ 2 | Store the arguments that can be modified by the user 3 | """ 4 | import argparse 5 | 6 | general_arguments = { 7 | "gpu": { 8 | "type": "bool", 9 | "default": None, 10 | "help": "whether use gpu" 11 | }, 12 | "gpu_id": { 13 | "type": "int", 14 | "default": None, 15 | "help": "the gpu id to use" 16 | }, 17 | "train_rate": { 18 | "type": "float", 19 | "default": None, 20 | "help": "the train set rate" 21 | }, 22 | "eval_rate": { 23 | "type": "float", 24 | "default": None, 25 | "help": "the validation set rate" 26 | }, 27 | "batch_size": { 28 | "type": "int", 29 | "default": None, 30 | "help": "the batch size" 31 | }, 32 | "learning_rate": { 33 | "type": "float", 34 | "default": None, 35 | "help": "learning rate" 36 | }, 37 | "max_epoch": { 38 | "type": "int", 39 | "default": None, 40 | "help": "the maximum epoch" 41 | }, 42 | "dataset_class": { 43 | "type": "str", 44 | "default": None, 45 | "help": "the dataset class name" 46 | }, 47 | "executor": { 48 | "type": "str", 49 | "default": None, 50 | "help": "the executor class name" 51 | }, 52 | "evaluator": { 53 | "type": "str", 54 | "default": None, 55 | "help": "the evaluator class name" 56 | }, 57 | }
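# NOTE: every entry in general_arguments becomes a command-line flag through
# add_general_args below. A minimal usage sketch (the flag values are illustrative):
#     parser = argparse.ArgumentParser()
#     add_general_args(parser)
#     args = parser.parse_args(['--batch_size', '64', '--gpu', 'true'])
#     # args.batch_size == 64 (int), args.gpu is True (parsed via str2bool)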
58 | 59 | hyper_arguments = { 60 | "gpu": { 61 | "type": "bool", 62 | "default": None, 63 | "help": "whether use gpu" 64 | }, 65 | "gpu_id": { 66 | "type": "int", 67 | "default": None, 68 | "help": "the gpu id to use" 69 | }, 70 | "train_rate": { 71 | "type": "float", 72 | "default": None, 73 | "help": "the train set rate" 74 | }, 75 | "eval_rate": { 76 | "type": "float", 77 | "default": None, 78 | "help": "the validation set rate" 79 | }, 80 | "batch_size": { 81 | "type": "int", 82 | "default": None, 83 | "help": "the batch size" 84 | } 85 | } 86 | 87 | 88 | def str2bool(s): 89 | if isinstance(s, bool): 90 | return s 91 | if s.lower() in ('yes', 'true'): 92 | return True 93 | elif s.lower() in ('no', 'false'): 94 | return False 95 | else: 96 | raise argparse.ArgumentTypeError('bool value expected.') 97 | 98 | 99 | def str2float(s): 100 | if isinstance(s, float): 101 | return s 102 | try: 103 | x = float(s) 104 | except ValueError: 105 | raise argparse.ArgumentTypeError('float value expected.') 106 | return x 107 | 108 | 109 | def add_general_args(parser): 110 | for arg in general_arguments: 111 | if general_arguments[arg]['type'] == 'int': 112 | parser.add_argument('--{}'.format(arg), type=int, 113 | default=general_arguments[arg]['default'], help=general_arguments[arg]['help']) 114 | elif general_arguments[arg]['type'] == 'bool': 115 | parser.add_argument('--{}'.format(arg), type=str2bool, 116 | default=general_arguments[arg]['default'], help=general_arguments[arg]['help']) 117 | elif general_arguments[arg]['type'] == 'float': 118 | parser.add_argument('--{}'.format(arg), type=str2float, 119 | default=general_arguments[arg]['default'], help=general_arguments[arg]['help']) 120 | elif general_arguments[arg]['type'] == 'str': 121 | parser.add_argument('--{}'.format(arg), type=str, 122 | default=general_arguments[arg]['default'], help=general_arguments[arg]['help']) 123 | elif general_arguments[arg]['type'] == 'list of int': 124 | parser.add_argument('--{}'.format(arg), nargs='+', type=int, 125 | default=general_arguments[arg]['default'], help=general_arguments[arg]['help']) 126 | 127 | 128 | def add_hyper_args(parser): 129 | for arg in hyper_arguments: 130 | if hyper_arguments[arg]['type'] == 'int': 131 | parser.add_argument('--{}'.format(arg), type=int, 132 | default=hyper_arguments[arg]['default'], help=hyper_arguments[arg]['help']) 133 | elif hyper_arguments[arg]['type'] == 'bool': 134 | parser.add_argument('--{}'.format(arg), type=str2bool, 135 | default=hyper_arguments[arg]['default'], help=hyper_arguments[arg]['help']) 136 | elif hyper_arguments[arg]['type'] == 'float': 137 | parser.add_argument('--{}'.format(arg), type=str2float, 138 | default=hyper_arguments[arg]['default'], help=hyper_arguments[arg]['help']) 139 | elif hyper_arguments[arg]['type'] == 'str': 140 | parser.add_argument('--{}'.format(arg), type=str, 141 | default=hyper_arguments[arg]['default'], help=hyper_arguments[arg]['help']) 142 | elif hyper_arguments[arg]['type'] == 'list of int': 143 | parser.add_argument('--{}'.format(arg), nargs='+', type=int, 144 | default=hyper_arguments[arg]['default'], help=hyper_arguments[arg]['help']) 145 | 146 | -------------------------------------------------------------------------------- /libcity/utils/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions used in the data-preprocessing stage 3 | """ 4 | import numpy as np 5 | import time 6 | from datetime import datetime, timedelta 7 | from collections import defaultdict 8 | 9 | 10 | def 
parse_time(time_in, timezone_offset_in_minute=0): 11 | """ 12 | Convert a `time` string in the json time_format into a local datetime 13 | """ 14 | date = datetime.strptime(time_in, '%Y-%m-%dT%H:%M:%SZ') # this is UTC time 15 | return date + timedelta(minutes=timezone_offset_in_minute) 16 | 17 | 18 | def cal_basetime(start_time, base_zero): 19 | """ 20 | Used to split a trajectory into sessions. 21 | The idea: given a start_time, find a base time base_time; 22 | points falling between base_time and base_time + time_length are grouped into one session. 23 | base_time is chosen this way so that the same hour-of-day slot is always encoded to the same number. 24 | """ 25 | if base_zero: 26 | return start_time - timedelta(hours=start_time.hour, 27 | minutes=start_time.minute, 28 | seconds=start_time.second, 29 | microseconds=start_time.microsecond) 30 | else: 31 | # time length = 12 32 | if start_time.hour < 12: 33 | return start_time - timedelta(hours=start_time.hour, 34 | minutes=start_time.minute, 35 | seconds=start_time.second, 36 | microseconds=start_time.microsecond) 37 | else: 38 | return start_time - timedelta(hours=start_time.hour - 12, 39 | minutes=start_time.minute, 40 | seconds=start_time.second, 41 | microseconds=start_time.microsecond) 42 | 43 | 44 | def cal_timeoff(now_time, base_time): 45 | """ 46 | Compute the difference between two times, returned in hours 47 | """ 48 | # first align `now` to the hour 49 | delta = now_time - base_time 50 | return delta.days * 24 + delta.seconds / 3600 51 | 52 | 53 | def caculate_time_sim(data): 54 | time_checkin_set = defaultdict(set) 55 | tim_size = data['tim_size'] 56 | data_neural = data['data'] 57 | for uid in data_neural: 58 | uid_sessions = data_neural[uid] 59 | for session in uid_sessions: 60 | for checkin in session: 61 | timid = checkin[1] 62 | locid = checkin[0] 63 | if timid not in time_checkin_set: 64 | time_checkin_set[timid] = set() 65 | time_checkin_set[timid].add(locid) 66 | sim_matrix = np.zeros((tim_size, tim_size)) 67 | for i in range(tim_size): 68 | for j in range(tim_size): 69 | set_i = time_checkin_set[i] 70 | set_j = time_checkin_set[j] 71 | if len(set_i | set_j) != 0: 72 | jaccard_ij = len(set_i & set_j) / len(set_i | set_j) 73 | sim_matrix[i][j] = jaccard_ij 74 | return sim_matrix 75 | 76 | 77 | def parse_coordinate(coordinate): 78 | items = coordinate[1:-1].split(',') 79 | return float(items[0]), float(items[1]) 80 | 81 | 82 | def string2timestamp(strings, offset_frame): 83 | ts = [] 84 | for t in strings: 85 | dtstr = '-'.join([t[:4].decode(), t[4:6].decode(), t[6:8].decode()]) 86 | slot = int(t[8:]) - 1 87 | ts.append(np.datetime64(dtstr, 'm') + slot * offset_frame) 88 | return ts # [numpy.datetime64('2014-01-01T00:00'), ...] 
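# NOTE: a minimal usage sketch of string2timestamp above, assuming 30-minute frames
# (offset_frame = np.timedelta64(30, 'm')) and byte strings of the form b'YYYYMMDDSS',
# where SS is the 1-based slot index within the day:
#     string2timestamp([b'2019010103'], np.timedelta64(30, 'm'))
#     # -> [numpy.datetime64('2019-01-01T01:00')]  (slot 3 starts at minute 60)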
--------------------------------------------------------------------------------
/libcity/utils/normalization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class Scaler:
5 |     """
6 |     Normalization interface
7 |     """
8 | 
9 |     def transform(self, data):
10 |         """
11 |         Normalize the data
12 | 
13 |         Args:
14 |             data(np.ndarray): data before normalization
15 | 
16 |         Returns:
17 |             np.ndarray: normalized data
18 |         """
19 |         raise NotImplementedError("Transform not implemented")
20 | 
21 |     def inverse_transform(self, data):
22 |         """
23 |         Denormalize the data
24 | 
25 |         Args:
26 |             data(np.ndarray): normalized data
27 | 
28 |         Returns:
29 |             np.ndarray: data before normalization
30 |         """
31 |         raise NotImplementedError("Inverse_transform not implemented")
32 | 
33 | 
34 | class NoneScaler(Scaler):
35 |     """
36 |     No normalization
37 |     """
38 | 
39 |     def transform(self, data):
40 |         return data
41 | 
42 |     def inverse_transform(self, data):
43 |         return data
44 | 
45 | 
46 | class NormalScaler(Scaler):
47 |     """
48 |     Normalize by dividing by the maximum:
49 |     x = x / x.max
50 |     """
51 | 
52 |     def __init__(self, maxx):
53 |         self.max = maxx
54 | 
55 |     def transform(self, data):
56 |         return data / self.max
57 | 
58 |     def inverse_transform(self, data):
59 |         return data * self.max
60 | 
61 | 
62 | class StandardScaler(Scaler):
63 |     """
64 |     Z-score normalization:
65 |     x = (x - x.mean) / x.std
66 |     """
67 | 
68 |     def __init__(self, mean, std):
69 |         self.mean = mean
70 |         self.std = std
71 | 
72 |     def transform(self, data):
73 |         return (data - self.mean) / self.std
74 | 
75 |     def inverse_transform(self, data):
76 |         return (data * self.std) + self.mean
77 | 
78 | 
79 | class MinMax01Scaler(Scaler):
80 |     """
81 |     Min-max normalization to the interval [0, 1]:
82 |     x = (x - min) / (max - min)
83 |     """
84 | 
85 |     def __init__(self, minn, maxx):
86 |         self.min = minn
87 |         self.max = maxx
88 | 
89 |     def transform(self, data):
90 |         return (data - self.min) / (self.max - self.min)
91 | 
92 |     def inverse_transform(self, data):
93 |         return data * (self.max - self.min) + self.min
94 | 
95 | 
96 | class MinMax11Scaler(Scaler):
97 |     """
98 |     Min-max normalization to the interval [-1, 1]:
99 |     x = (x - min) / (max - min)
100 |     x = x * 2 - 1
101 |     """
102 | 
103 |     def __init__(self, minn, maxx):
104 |         self.min = minn
105 |         self.max = maxx
106 | 
107 |     def transform(self, data):
108 |         return ((data - self.min) / (self.max - self.min)) * 2. - 1.
109 | 
110 |     def inverse_transform(self, data):
111 |         return ((data + 1.) / 2.) * (self.max - self.min) + self.min
112 | 
113 | 
114 | class LogScaler(Scaler):
115 |     """
116 |     Log scaler:
117 |     x = log(x + eps)
118 |     """
119 | 
120 |     def __init__(self, eps=0.999):
121 |         self.eps = eps
122 | 
123 |     def transform(self, data):
124 |         return np.log(data + self.eps)
125 | 
126 |     def inverse_transform(self, data):
127 |         return np.exp(data) - self.eps
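128 | 
129 | 
130 | # Illustrative sketch (not part of the original module): a transform /
131 | # inverse_transform round trip with StandardScaler, matching the z-score
132 | # re-transformation used in result_convert.py; the toy array is an assumed example.
133 | if __name__ == '__main__':
134 |     x = np.array([10., 20., 30.])
135 |     scaler = StandardScaler(mean=x.mean(), std=x.std())
136 |     z = scaler.transform(x)  # zero mean, unit variance
137 |     print(np.allclose(scaler.inverse_transform(z), x))  # True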
""" 84 | 85 | def __init__(self, minn, maxx): 86 | self.min = minn 87 | self.max = maxx 88 | 89 | def transform(self, data): 90 | return (data - self.min) / (self.max - self.min) 91 | 92 | def inverse_transform(self, data): 93 | return data * (self.max - self.min) + self.min 94 | 95 | 96 | class MinMax11Scaler(Scaler): 97 | """ 98 | MinMax归一化 结果区间[-1, 1] 99 | x = (x - min) / (max - min) 100 | x = x * 2 - 1 101 | """ 102 | 103 | def __init__(self, minn, maxx): 104 | self.min = minn 105 | self.max = maxx 106 | 107 | def transform(self, data): 108 | return ((data - self.min) / (self.max - self.min)) * 2. - 1. 109 | 110 | def inverse_transform(self, data): 111 | return ((data + 1.) / 2.) * (self.max - self.min) + self.min 112 | 113 | 114 | class LogScaler(Scaler): 115 | """ 116 | Log scaler 117 | x = log(x+eps) 118 | """ 119 | 120 | def __init__(self, eps=0.999): 121 | self.eps = eps 122 | 123 | def transform(self, data): 124 | return np.log(data + self.eps) 125 | 126 | def inverse_transform(self, data): 127 | return np.exp(data) - self.eps 128 | -------------------------------------------------------------------------------- /libcity/utils/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import datetime 4 | import os 5 | import sys 6 | import numpy as np 7 | import random 8 | import torch 9 | 10 | 11 | def get_executor(config, model, data_feature): 12 | """ 13 | according the config['executor'] to create the executor 14 | 15 | Args: 16 | config(ConfigParser): config 17 | model(AbstractModel): model 18 | 19 | Returns: 20 | AbstractExecutor: the loaded executor 21 | """ 22 | try: 23 | return getattr(importlib.import_module('libcity.executor'), 24 | config['executor'])(config, model, data_feature) 25 | except AttributeError: 26 | raise AttributeError('executor is not found') 27 | 28 | 29 | def get_model(config, data_feature): 30 | """ 31 | according the config['model'] to create the model 32 | 33 | Args: 34 | config(ConfigParser): config 35 | data_feature(dict): feature of the data 36 | 37 | Returns: 38 | AbstractModel: the loaded model 39 | """ 40 | if config['task'] == 'traj_loc_pred': 41 | try: 42 | return getattr(importlib.import_module('libcity.model.trajectory_loc_prediction'), 43 | config['model'])(config, data_feature) 44 | except AttributeError: 45 | raise AttributeError('model is not found') 46 | elif config['task'] == 'traffic_state_pred': 47 | try: 48 | return getattr(importlib.import_module('libcity.model.traffic_flow_prediction'), 49 | config['model'])(config, data_feature) 50 | except AttributeError: 51 | try: 52 | return getattr(importlib.import_module('libcity.model.traffic_speed_prediction'), 53 | config['model'])(config, data_feature) 54 | except AttributeError: 55 | try: 56 | return getattr(importlib.import_module('libcity.model.traffic_demand_prediction'), 57 | config['model'])(config, data_feature) 58 | except AttributeError: 59 | try: 60 | return getattr(importlib.import_module('libcity.model.traffic_od_prediction'), 61 | config['model'])(config, data_feature) 62 | except AttributeError: 63 | try: 64 | return getattr(importlib.import_module('libcity.model.traffic_accident_prediction'), 65 | config['model'])(config, data_feature) 66 | except AttributeError: 67 | raise AttributeError('model is not found') 68 | elif config['task'] == 'map_matching': 69 | try: 70 | return getattr(importlib.import_module('libcity.model.map_matching'), 71 | config['model'])(config, data_feature) 72 | except 
88 | 
89 | 
90 | def get_evaluator(config):
91 |     """
92 |     Create the evaluator according to config['evaluator']
93 | 
94 |     Args:
95 |         config(ConfigParser): config
96 | 
97 |     Returns:
98 |         AbstractEvaluator: the loaded evaluator
99 |     """
100 |     try:
101 |         return getattr(importlib.import_module('libcity.evaluator'),
102 |                        config['evaluator'])(config)
103 |     except AttributeError:
104 |         raise AttributeError('evaluator is not found')
105 | 
106 | 
107 | def get_logger(config, name=None):
108 |     """
109 |     Get a Logger object configured from `config`
110 | 
111 |     Args:
112 |         config(ConfigParser): config
113 |         name: specified name
114 | 
115 |     Returns:
116 |         Logger: logger
117 |     """
118 |     log_dir = './libcity/log'
119 |     if not os.path.exists(log_dir):
120 |         os.makedirs(log_dir)
121 |     log_filename = '{}-{}-{}-{}.log'.format(config['exp_id'],
122 |                                             config['model'], config['dataset'], get_local_time())
123 |     logfilepath = os.path.join(log_dir, log_filename)
124 | 
125 |     logger = logging.getLogger(name)
126 | 
127 |     log_level = config.get('log_level', 'INFO')
128 | 
129 |     if log_level.lower() == 'info':
130 |         level = logging.INFO
131 |     elif log_level.lower() == 'debug':
132 |         level = logging.DEBUG
133 |     elif log_level.lower() == 'error':
134 |         level = logging.ERROR
135 |     elif log_level.lower() == 'warning':
136 |         level = logging.WARNING
137 |     elif log_level.lower() == 'critical':
138 |         level = logging.CRITICAL
139 |     else:
140 |         level = logging.INFO
141 | 
142 |     logger.setLevel(level)
143 | 
144 |     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
145 |     file_handler = logging.FileHandler(logfilepath)
146 |     file_handler.setFormatter(formatter)
147 | 
148 |     console_formatter = logging.Formatter(
149 |         '%(asctime)s - %(levelname)s - %(message)s')
150 |     console_handler = logging.StreamHandler(sys.stdout)
151 |     console_handler.setFormatter(console_formatter)
152 | 
153 |     logger.addHandler(file_handler)
154 |     logger.addHandler(console_handler)
155 | 
156 |     logger.info('Log directory: %s', log_dir)
157 |     return logger
158 | 
159 | 
160 | def get_local_time():
161 |     """
162 |     Get the current local time as a formatted string
163 | 
164 |     Return:
165 |         str: current time, formatted like 'Jan-01-2019_12-00-00'
166 |     """
167 |     cur = datetime.datetime.now()
168 |     cur = cur.strftime('%b-%d-%Y_%H-%M-%S')
169 |     return cur
170 | 
171 | 
172 | def ensure_dir(dir_path):
173 |     """Make sure the directory exists; if it does not exist, create it.
174 | 
175 |     Args:
176 |         dir_path (str): directory path
177 |     """
178 |     if not os.path.exists(dir_path):
179 |         os.makedirs(dir_path)
180 | 
181 | 
182 | def trans_naming_rule(origin, origin_rule, target_rule):
183 |     """
184 |     Convert a variable name from one naming rule to another
185 | 
186 |     Args:
187 |         origin (str): variable name under the source naming rule
188 |         origin_rule (str): source naming rule
189 |         target_rule (str): target naming rule
190 | 
191 |     Return:
192 |         target (str): the converted name
193 |     """
194 |     # TODO: the input is assumed to follow origin_rule; no check is performed here
195 |     target = ''
196 |     if origin_rule == 'upper_camel_case' and target_rule == 'under_score_rule':
197 |         for i, c in enumerate(origin):
198 |             if i == 0:
199 |                 target = c.lower()
200 |             else:
201 |                 target += '_' + c.lower() if c.isupper() else c
202 |         return target
203 |     else:
204 |         raise NotImplementedError(
205 |             'trans naming rule only support from upper_camel_case to \
206 |             under_score_rule')
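207 | 
208 | 
209 | # Illustrative note (not part of the original module): for example,
210 | # trans_naming_rule('StandardScaler', 'upper_camel_case', 'under_score_rule')
211 | # returns 'standard_scaler'.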
212 | 
213 | 
214 | def preprocess_data(data, config):
215 |     """
216 |     Split the data by input_window and output_window
217 | 
218 |     Args:
219 |         data: shape (T, ...)
220 |         config: a ConfigParser (or dict) providing train_rate, eval_rate,
221 |             input_window and output_window
222 | 
223 |     Returns:
224 |         np.ndarray: (train_size/test_size, input_window, ...)
225 |                     (train_size/test_size, output_window, ...)
226 |     """
227 |     train_rate = config.get('train_rate', 0.7)
228 |     eval_rate = config.get('eval_rate', 0.1)
229 | 
230 |     input_window = config.get('input_window', 12)
231 |     output_window = config.get('output_window', 3)
232 | 
233 |     x, y = [], []
234 |     for i in range(len(data) - input_window - output_window):
235 |         a = data[i: i + input_window + output_window]  # (in+out, ...)
236 |         x.append(a[0: input_window])  # (in, ...)
237 |         y.append(a[input_window: input_window + output_window])  # (out, ...)
238 |     x = np.array(x)  # (num_samples, in, ...)
239 |     y = np.array(y)  # (num_samples, out, ...)
240 | 
241 |     train_size = int(x.shape[0] * (train_rate + eval_rate))
242 |     trainx = x[:train_size]  # (train_size, in, ...)
243 |     trainy = y[:train_size]  # (train_size, out, ...)
244 |     testx = x[train_size:x.shape[0]]  # (test_size, in, ...)
245 |     testy = y[train_size:x.shape[0]]  # (test_size, out, ...)
246 |     return trainx, trainy, testx, testy
247 | 
248 | 
249 | def set_random_seed(seed):
250 |     """
251 |     Reset the random seed of random, numpy and torch for reproducibility
252 | 
253 |     Args:
254 |         seed(int): the seed
255 |     """
256 |     random.seed(seed)
257 |     np.random.seed(seed)
258 |     torch.manual_seed(seed)
259 |     torch.cuda.manual_seed_all(seed)
260 |     torch.backends.cudnn.deterministic = True
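261 | 
262 | 
263 | # Illustrative sketch (not part of the original module): shapes produced by
264 | # preprocess_data's sliding-window split, with a toy series and a plain dict
265 | # standing in for the ConfigParser (only .get is needed here).
266 | if __name__ == '__main__':
267 |     series = np.arange(100).reshape(100, 1)  # (T, ...) with T = 100
268 |     cfg = {'train_rate': 0.7, 'eval_rate': 0.1, 'input_window': 12, 'output_window': 3}
269 |     trainx, trainy, testx, testy = preprocess_data(series, cfg)
270 |     print(trainx.shape, trainy.shape)  # (68, 12, 1) (68, 3, 1)
271 |     print(testx.shape, testy.shape)    # (17, 12, 1) (17, 3, 1)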
--------------------------------------------------------------------------------
/other_data/CTractFIPS_201901010601_BM_visit_mstd.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/other_data/CTractFIPS_201901010601_BM_visit_mstd.pkl
--------------------------------------------------------------------------------
/other_data/CTractFIPS_201901010601_DC_visit_mstd.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/other_data/CTractFIPS_201901010601_DC_visit_mstd.pkl
--------------------------------------------------------------------------------
/raw_data/201901010601_BM_SG_CTractFIPS_Hourly_Single_GP/201901010601_BM_SG_CTractFIPS_Hourly_Single_GP.7z:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/raw_data/201901010601_BM_SG_CTractFIPS_Hourly_Single_GP/201901010601_BM_SG_CTractFIPS_Hourly_Single_GP.7z
--------------------------------------------------------------------------------
/raw_data/201901010601_BM_SG_CTractFIPS_Hourly_Single_GP/config.json:
--------------------------------------------------------------------------------
1 | {"geo": {"including_types": ["Point"], "Point": {}}, "rel": {"including_types": ["geo"], "geo": {"link_weight": "num"}}, "dyna": {"including_types": ["state"], "state": {"entity_id": "geo_id", "Visits": "num"}}, "ext": {"ext_id": "num", "time": "other", "holiday": "num", "weekend": "num", "temp": "num", "rain": "num", "snow": "num", "New_cases": "num"}, "info": {"data_col": ["Visits"], "weight_col": "link_weight", "ext_col": ["holiday", "weekend", "temp", "rain", "snow"], "data_files": ["201901010601_BM_SG_CTractFIPS_Hourly_Single_GP"], "geo_file": "201901010601_BM_SG_CTractFIPS_Hourly_Single_GP", "rel_file": "201901010601_BM_SG_CTractFIPS_Hourly_Single_GP", "ext_file": "201901010601_BM_SG_CTractFIPS_Hourly_Single_GP", "output_dim": 1, "time_intervals": 3600, "init_weight_inf_or_zero": "zero", "set_weight_link_or_dist": "dist", "calculate_weight_adj": false, "weight_adj_epsilon": 0.1}}
--------------------------------------------------------------------------------
/raw_data/201901010601_DC_SG_CTractFIPS_Hourly_Single_GP/201901010601_DC_SG_CTractFIPS_Hourly_Single_GP.7z:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SonghuaHu-UMD/MultiSTGraph/96fab1754905336cca08cc83200b6e81f3aa6c43/raw_data/201901010601_DC_SG_CTractFIPS_Hourly_Single_GP/201901010601_DC_SG_CTractFIPS_Hourly_Single_GP.7z
--------------------------------------------------------------------------------
/raw_data/201901010601_DC_SG_CTractFIPS_Hourly_Single_GP/config.json:
--------------------------------------------------------------------------------
["state"], "state": {"entity_id": "geo_id", "Visits": "num"}}, "ext": {"ext_id": "num", "time": "other", "holiday": "num", "weekend": "num", "temp": "num", "rain": "num", "snow": "num", "New_cases": "num"}, "info": {"data_col": ["Visits"], "weight_col": "link_weight", "ext_col": ["holiday", "weekend", "temp", "rain", "snow"], "data_files": ["201901010601_DC_SG_CTractFIPS_Hourly_Single_GP"], "geo_file": "201901010601_DC_SG_CTractFIPS_Hourly_Single_GP", "rel_file": "201901010601_DC_SG_CTractFIPS_Hourly_Single_GP", "ext_file": "201901010601_DC_SG_CTractFIPS_Hourly_Single_GP", "output_dim": 1, "time_intervals": 3600, "init_weight_inf_or_zero": "zero", "set_weight_link_or_dist": "dist", "calculate_weight_adj": false, "weight_adj_epsilon": 0.1}} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gensim==4.2.0 2 | geopandas==0.8.1 3 | hyperopt==0.2.7 4 | matplotlib==3.3.4 5 | networkx==2.5.1 6 | numpy==1.19.5 7 | pandas==1.1.5 8 | ray==1.7.1 9 | scikit_learn==1.1.2 10 | scipy==1.5.0 11 | seaborn==0.11.2 12 | torch==1.12.1 13 | ~orch==1.6.0 14 | -------------------------------------------------------------------------------- /result_convert.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import glob 5 | import os 6 | import matplotlib as mpl 7 | from libcity.model import loss 8 | from sklearn.metrics import r2_score, explained_variance_score 9 | import datetime 10 | import random 11 | 12 | random.seed(10) 13 | 14 | pd.options.mode.chained_assignment = None 15 | results_path = r'D:\ST_Graph\results_record\\' 16 | 17 | 18 | # Give a dir and read all files inside the dir 19 | def get_gp_data(filenames): 20 | filenames = [ec for ec in filenames if 'log' not in ec] 21 | all_results = pd.DataFrame() 22 | for ec in filenames: 23 | nec = glob.glob(ec + '\\evaluate_cache\\*.csv') 24 | model_name = glob.glob(ec + '\\model_cache\\*.m') 25 | if len(nec) > 0: 26 | fec = pd.read_csv(nec[0]) 27 | fec['Model_name'] = model_name[0].split('\\')[-1].split('_')[0] 28 | fec['Model_time'] = datetime.datetime.fromtimestamp(os.path.getmtime(nec[0])) 29 | all_results = all_results.append(fec) 30 | all_results = all_results.reset_index() 31 | return all_results 32 | 33 | 34 | def transfer_gp_data(filenames, ct_visit_mstd, s_small=10): 35 | m_m = [] 36 | for kk in filenames: 37 | print(kk) 38 | filename = glob.glob(kk + r"\\evaluate_cache\*.npz") 39 | model_name = glob.glob(kk + '\\model_cache\\*.m') 40 | if len(model_name) > 0: 41 | model_name = model_name[0].split('\\')[-1].split('_')[0] 42 | print(model_name) 43 | Predict_R = np.load(filename[0]) 44 | # drop the last batch 45 | pred = Predict_R['prediction'][:-16, :, :, :] 46 | truth = Predict_R['truth'][:-16, :, :, :] 47 | sh = pred.shape 48 | print(sh) # no of batches, output_window, no of nodes, output dim 49 | ct_ma = np.tile(ct_visit_mstd[['All_m']].values, (sh[0], sh[1], 1, sh[3])) 50 | ct_sa = np.tile(ct_visit_mstd[['All_std']].values, (sh[0], sh[1], 1, sh[3])) 51 | ct_id = np.tile(ct_visit_mstd[[sunit]].values, (sh[0], sh[1], 1, sh[3])) 52 | ahead_step = np.tile(np.expand_dims(np.array(range(0, sh[1])), axis=(1, 2)), (sh[0], 1, sh[2], sh[3])) 53 | P_R = pd.DataFrame({'prediction': pred.flatten(), 'truth': truth.flatten(), 54 | 'All_m': ct_ma.flatten(), 'All_std': ct_sa.flatten(), sunit: ct_id.flatten(), 55 | 
78 | 
79 | 
80 | ############ Read metrics of multiple models ############
81 | time_sps, n_steps, nfold = ['201901010601_BM', '201901010601_DC'], [3, 6, 12, 24], 'Final'
82 | for time_sp in time_sps:
83 |     for n_step in n_steps:
84 |         # time_sp = '201901010601_BM'
85 |         sunit = 'CTractFIPS'
86 |         filenames = glob.glob(results_path + r"%s steps\%s\%s\*" % (n_step, nfold, time_sp))
87 |         all_results = get_gp_data(filenames)
88 |         if len(all_results) > 0:
89 |             all_results_avg = all_results.groupby(['Model_name']).mean().sort_values(by='MAE').reset_index()
90 |             # all_results_avg = all_results_avg[~all_results_avg['Model_name'].isin(['STSGCN', 'STTN', 'Seq2Seq'])]
91 |             all_results_avg = all_results_avg.sort_values(by='MAE').reset_index()
92 |             n_col = all_results_avg.select_dtypes('number').columns
93 |             all_results_avg.to_csv(
94 |                 r"D:\ST_Graph\Results\final\M_%s_gp_%s_steps_%s_%s.csv" % (nfold, n_step, sunit, time_sp))
95 | 
96 |             # Re-transform the data
97 |             ct_visit_mstd = pd.read_pickle(r'.\other_data\%s_%s_visit_mstd.pkl' % (sunit, time_sp)).sort_values(
98 |                 by=sunit).reset_index(drop=True)
99 |             m_m = transfer_gp_data(filenames, ct_visit_mstd)
100 |             m_md = pd.DataFrame(m_m)
101 |             m_md.columns = ['Model_name', 'index', 'Model_time', 'MAE', 'MSE', 'RMSE', 'R2', 'EVAR', 'MAPE']
102 |             avg_t = m_md.groupby(['Model_name']).mean().sort_values(by='MAE').reset_index()
103 |             avg_t = avg_t[~avg_t['Model_name'].isin(['STSGCN', 'STTN', 'Seq2Seq', 'TGCN'])]
104 |             avg_t.to_csv(
105 |                 r"D:\ST_Graph\Results\final\M_%s_truth_%s_steps_%s_%s.csv" % (nfold, n_step, sunit, time_sp))
106 | 
107 | # Baseline comparison
108 | # Read metrics for each model and format the table
109 | time_sps, n_steps, nfold, sunit = ['201901010601_BM', '201901010601_DC'], [3, 6, 12, 24], 'Final', 'CTractFIPS'
110 | All_metrics = pd.DataFrame()
111 | for time_sp in time_sps:
112 |     for n_step in n_steps:
113 |         avg_t = pd.read_csv(r"D:\ST_Graph\Results\final\M_%s_truth_%s_steps_%s_%s.csv" % (
114 |             nfold, n_step, sunit, time_sp), index_col=0)
115 |         avg_t['Step_'] = n_step
116 |         avg_t['data'] = time_sp
117 |         All_metrics = All_metrics.append(avg_t[['Model_name', 'MAE', 'RMSE', 'R2', 'MAPE', 'Step_', 'data']])
118 | 
119 | All_metrics_base = All_metrics[All_metrics['Model_name'] == 'MultiATGCN']
120 | All_metrics_base.columns = ['B_Model_name', 'B_MAE', 'B_RMSE', 'B_R2', 'B_MAPE', 'Step_', 'data']
121 | All_metrics = All_metrics.merge(All_metrics_base, on=['Step_', 'data'])
122 | for kk in ['MAE', 'RMSE', 'R2', 'MAPE']:
123 |     All_metrics['Pct_' + kk] = 100 * (All_metrics[kk] - All_metrics['B_' + kk]) / All_metrics[kk]
124 |     All_metrics[kk] = All_metrics[kk].round(3).map('{:.2f}'.format).astype(str) + ' (' + \
125 |                       All_metrics['Pct_' + kk].round(3).map('{:.1f}'.format).astype(str) + '%)'
126 | All_metrics = All_metrics.sort_values(by=['data', 'Step_', 'Pct_MAE'], ascending=[True, True, False])
127 | All_metrics = All_metrics[~All_metrics['Model_name'].isin(['Seq2Seq'])]
128 | All_metrics = All_metrics[['Model_name', 'Step_', 'data', 'MAE', 'RMSE', 'R2', 'MAPE']]
129 | All_metrics_f = All_metrics.pivot(index=['Step_', 'Model_name'], columns=['data'],
130 |                                   values=['MAE', 'RMSE', 'MAPE']).reset_index()
131 | All_metrics_f['Sort'] = All_metrics_f['MAE']['201901010601_BM'].str.split(' ', 1, expand=True)[0].astype(float)
132 | All_metrics_f = All_metrics_f.sort_values(by=['Step_', 'Sort'], ascending=[True, False])
133 | idx = pd.IndexSlice
134 | pd.concat([All_metrics_f[['Model_name', 'Step_']], All_metrics_f.loc[:, idx[:, '201901010601_BM']],
135 |            All_metrics_f.loc[:, idx[:, '201901010601_DC']]], axis=1).to_csv(
136 |     r'D:\ST_Graph\Results\All_M_metrics_format.csv', index=0)
137 | 
138 | ########### Read metrics of multiple parameters
139 | para_list = [str(x) for x in
140 |              [['od', 'bidirection'], ['od', 'unidirection'], ['od', 'none'], ['dist', 'none'], ['cosine', 'none'],
141 |               ['identity', 'none'], ['multi', 'bidirection']]]
142 | # para_list = [True, False]
143 | # para_list = [16, 32, 64, 72]
144 | time_sps, n_repeat, para_name, n_steps, sunit = ['201901010601_BM'], 4, 'P_graph_new', 24, 'CTractFIPS'
145 | for time_sp in time_sps:
146 |     filenames = glob.glob(results_path + r"%s steps\%s\%s\*" % (n_steps, para_name, time_sp))
147 |     all_results = get_gp_data(filenames)
148 |     all_results = all_results.sort_values(by=['Model_time', 'index']).reset_index(drop=True)
149 |     all_results['Para'] = np.repeat(para_list, n_steps * n_repeat)
150 |     all_results_avg = all_results.groupby(['Para']).mean().sort_values(by='MAE').reset_index()
151 |     all_results_avg.to_csv(r"D:\ST_Graph\Results\results_%s_gp_%s_%s.csv" % (para_name, sunit, time_sp))
152 | 
153 |     # Re-transform the data
154 |     ct_visit_mstd = pd.read_pickle(r'D:\ST_Graph\Results\%s_%s_visit_mstd.pkl' % (sunit, time_sp))
155 |     ct_visit_mstd = ct_visit_mstd.sort_values(by=sunit).reset_index(drop=True)
156 |     # Read prediction result
157 |     m_m = transfer_gp_data(filenames, ct_visit_mstd, s_small=10)
158 |     m_md = pd.DataFrame(m_m)
159 |     m_md.columns = ['Model_name', 'index', 'Model_time', 'MAE', 'MSE', 'RMSE', 'R2', 'EVAR', 'MAPE']
160 |     m_md = m_md.sort_values(by=['Model_time', 'index']).reset_index(drop=True)
161 |     m_md['Para'] = np.repeat(para_list, n_steps * n_repeat)
162 |     avg_t = m_md.groupby(['Para'])[['MAE', 'RMSE', 'MAPE']].mean().reset_index()
163 |     avg_t.columns = ['Para', 'MAE_mean', 'RMSE_mean', 'MAPE_mean']
164 |     avg_std = m_md.groupby(['Para'])[['MAE', 'RMSE', 'MAPE']].std().reset_index()
165 |     avg_std.columns = ['Para', 'MAE_std', 'RMSE_std', 'MAPE_std']
166 |     avg_t = avg_t.merge(avg_std, on=['Para']).sort_values(by='MAE_mean')
167 |     avg_t = avg_t[['Para', 'MAE_mean', 'MAE_std', 'RMSE_mean', 'RMSE_std', 'MAPE_mean', 'MAPE_std']]
168 |     avg_t.to_csv(r"D:\ST_Graph\Results\results_mstd_%s_truth_%s_%s.csv" % (para_name, sunit, time_sp))
--------------------------------------------------------------------------------
/run_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from libcity.pipeline import run_model
3 | from libcity.utils import str2bool, add_general_args
4 | 
5 | # # Dataset: COVID01010401_SG_CTractFIPS_Hourly_Single_GP SG_CTractFIPS_Hourly_Single_GP
6 | # model_list = ['MultiATGCN', 'AGCRN', 'ASTGCN', 'STGCN', 'MTGNN', 'GWNET', 'GMAN', 'STTN', "GRU", 'LSTM',
7 | #               'RNN', 'Seq2Seq', 'FNN', 'TGCN', 'DCRNN']
8 | model_list = ['MultiATGCN']
9 | if __name__ == '__main__':
10 |     for model_name in model_list:
11 |         parser = argparse.ArgumentParser()
12 |         parser.add_argument('--task', type=str, default='traffic_state_pred', help='the name of task')
13 |         parser.add_argument('--model', type=str, default=model_name, help='the name of model')
14 |         parser.add_argument('--dataset', type=str, default='201901010601_DC_SG_CTractFIPS_Hourly_Single_GP',
15 |                             help='the name of dataset')
16 |         parser.add_argument('--config_file', type=str, default='config_user', help='the file name of config file')
17 |         parser.add_argument('--saved_model', type=str2bool, default=True, help='whether save the trained model')
18 |         parser.add_argument('--train', type=str2bool, default=True, help='whether re-train if the model is trained')
19 |         parser.add_argument('--exp_id', type=str, default=None, help='id of experiment')
20 |         parser.add_argument('--seed', type=int, default=100, help='random seed')
21 |         parser.add_argument('--start_dim', type=int, default=0, help='start_dim')
22 |         parser.add_argument('--end_dim', type=int, default=1, help='end_dim')
23 |         add_general_args(parser)
24 |         args, unknown = parser.parse_known_args()
25 |         dict_args = vars(args)
26 |         other_args = {key: val for key, val in dict_args.items() if key not in
27 |                       ['task', 'model', 'dataset', 'config_file', 'saved_model', 'train'] and val is not None}
28 |         run_model(task=args.task, model_name=args.model, dataset_name=args.dataset, config_file=args.config_file,
29 |                   saved_model=args.saved_model, train=args.train, other_args=other_args)
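30 | 
31 | # Illustrative usage (not part of the original script); any flag registered by
32 | # add_general_args can be appended in the same way:
33 | #   python run_model.py --model MultiATGCN --dataset 201901010601_BM_SG_CTractFIPS_Hourly_Single_GP --seed 0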
--------------------------------------------------------------------------------
/run_model_parameter.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from libcity.pipeline import run_model
3 | from libcity.utils import str2bool, add_general_args
4 | 
5 | model_list = ['MultiATGCN']
6 | # para_list = [[1, 0, 0], [0, 0, 1], [0, 1, 1], [1, 0, 1], [1, 1, 1], [2, 1, 1], [3, 1, 1], [1, 2, 1], [1, 3, 1],
7 | #              [2, 2, 1], [2, 3, 1], [3, 3, 1], [1, 1, 0], [1, 1, 2], [1, 1, 3]]
8 | # para_list = [['od', 'bidirection'], ['od', 'unidirection'], ['od', 'none'], ['dist', 'none'], ['cosine', 'none'],
9 | #              ['identity', 'none'], ['multi', 'bidirection']]
10 | # para_list = [[True, True, True, True], [True, True, False, False], [False, True, False, False],
11 | #              [False, True, False, True], [False, False, False, False]]
12 | # para_list = [True, False]
13 | # para_list = [16, 32, 64, 72]
14 | # para_list = [1, 5, 10, 20, 30, 50]
15 | para_list = [1, 2, 3]
16 | # para_list = [False]
17 | if __name__ == '__main__':
18 |     for model_name in model_list:
19 |         for para in para_list:
20 |             for random_seed in [0, 10, 100, 1000]:
21 |                 for dataset in ['201901010601_BM_SG_CTractFIPS_Hourly_Single_GP']:
22 |                     print(para)
23 |                     parser = argparse.ArgumentParser()
24 |                     parser.add_argument('--task', type=str, default='traffic_state_pred', help='the name of task')
25 |                     parser.add_argument('--model', type=str, default=model_name, help='the name of model')
26 |                     parser.add_argument('--dataset', type=str, default=dataset, help='the name of dataset')
27 |                     parser.add_argument('--config_file', type=str, default='config_user', help='config file')
28 |                     parser.add_argument('--saved_model', type=str2bool, default=True, help='saved_model')
29 |                     parser.add_argument('--train', type=str2bool, default=True,
30 |                                         help='whether re-train if the model is trained')
31 |                     parser.add_argument('--exp_id', type=str, default=None, help='id of experiment')
32 |                     parser.add_argument('--seed', type=int, default=random_seed, help='random seed')
33 |                     parser.add_argument('--start_dim', type=int, default=0, help='start_dim')
34 |                     parser.add_argument('--end_dim', type=int, default=1, help='end_dim')
35 |                     # parser.add_argument('--embed_dim_node', type=int, default=para, help='embed_dim_node')
36 |                     parser.add_argument('--cheb_order', type=int, default=para, help='cheb_order')
37 |                     # parser.add_argument('--rnn_units', type=int, default=para, help='rnn_units')
38 |                     # parser.add_argument('--adjtype', type=str, default=para[0], help='adjtype')
39 |                     # parser.add_argument('--adpadj', type=str, default=para[1], help='adpadj')
40 |                     # parser.add_argument('--node_specific_off', type=bool, default=para, help='node_specific_off')
41 |                     # parser.add_argument('--gcn_off', type=bool, default=para, help='gcn_off')
42 |                     # parser.add_argument('--fnn_off', type=bool, default=para, help='fnn_off')
43 |                     # parser.add_argument('--load_dynamic', type=bool, default=para[0], help='load_dynamic')
44 |                     # parser.add_argument('--add_time_in_day', type=bool, default=para[1], help='add_time_in_day')
45 |                     # parser.add_argument('--add_day_in_week', type=bool, default=para[2], help='add_day_in_week')
46 |                     # parser.add_argument('--add_static', type=bool, default=para[3], help='add_static')
47 |                     # parser.add_argument('--len_closeness', type=int, default=para[0], help='len_closeness')
48 |                     # parser.add_argument('--len_period', type=int, default=para[1], help='len_period')
49 |                     # parser.add_argument('--len_trend', type=int, default=para[2], help='len_trend')
50 |                     add_general_args(parser)
51 |                     # args = parser.parse_args()
52 |                     args, unknown = parser.parse_known_args()
53 |                     dict_args = vars(args)
54 |                     other_args = {key: val for key, val in dict_args.items() if key not in
55 |                                   ['task', 'model', 'dataset', 'config_file', 'saved_model',
56 |                                    'train'] and val is not None}
57 |                     run_model(task=args.task, model_name=args.model, dataset_name=args.dataset,
58 |                               config_file=args.config_file, saved_model=args.saved_model, train=args.train,
59 |                               other_args=other_args)
--------------------------------------------------------------------------------